From 6a175697339cca91ad6979b9776e15b130354906 Mon Sep 17 00:00:00 2001
From: Tristan Nixon
Date: Thu, 11 Aug 2022 16:35:49 -0700
Subject: [PATCH 01/11] created new TSIndex and TSSchema classes to represent
TSDF metadata. First round of TSDF code changes to use the new classes
---
python/tempo/tsdf.py | 245 ++++++++++++++++++++++-----------------
python/tempo/tsschema.py | 99 ++++++++++++++++
2 files changed, 236 insertions(+), 108 deletions(-)
create mode 100644 python/tempo/tsschema.py
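In outline, this first patch replaces the old (ts_col, partition_cols, sequence_col) constructor with one driven by a TSSchema, exposing the series identifiers as series_ids. A minimal sketch of the intended usage, assuming a Spark DataFrame quotes_df with an event_ts timestamp column and a symbol series identifier (both names are illustrative, not taken from the patch):

    from tempo.tsdf import TSDF

    # old API (removed by this patch):
    #   tsdf = TSDF(quotes_df, ts_col="event_ts", partition_cols=["symbol"])

    # new API: series_ids replaces partition_cols; when no TSSchema is
    # supplied, one is derived from the DataFrame schema and validated
    tsdf = TSDF(quotes_df, ts_col="event_ts", series_ids=["symbol"])

    tsdf.series_ids       # ["symbol"]
    tsdf.structural_cols  # {"event_ts", "symbol"}
    tsdf.metric_cols      # names of numeric, non-structural columns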
diff --git a/python/tempo/tsdf.py b/python/tempo/tsdf.py
index 6a515b33..b5c50d2a 100644
--- a/python/tempo/tsdf.py
+++ b/python/tempo/tsdf.py
@@ -3,7 +3,7 @@
import logging
import operator
from functools import reduce
-from typing import List, Union, Callable
+from typing import List, Union, Callable, Collection, Set
import numpy as np
import pyspark.sql.functions as f
@@ -18,6 +18,7 @@
import tempo.io as tio
import tempo.resample as rs
from tempo.interpol import Interpolation
+from tempo.tsschema import TSSchema
from tempo.utils import (
ENV_CAN_RENDER_HTML,
IS_DATABRICKS,
@@ -33,33 +34,84 @@ class TSDF:
This object is the main wrapper over a Spark DataFrame, allowing a user to parallelize time series computations across various dimensions. The two required dimensions are partition_cols (a list of columns by which to summarize) and ts_col (the timestamp column, which can be epoch or TimestampType).
"""
- def __init__(self, df, ts_col="event_ts", partition_cols=None, sequence_col=None):
- """
- Constructor
- :param df:
- :param ts_col:
- :param partitionCols:
- :sequence_col every tsdf allows for a tie-breaker secondary sort key
- """
- self.ts_col = self.__validated_column(df, ts_col)
- self.partitionCols = (
- []
- if partition_cols is None
- else self.__validated_columns(df, partition_cols.copy())
- )
-
+ def __init__(
+ self,
+ df: DataFrame,
+ ts_schema: TSSchema = None,
+ ts_col: str = None,
+ series_ids: Collection[str] = None,
+ validate_schema=True,
+ ) -> None:
self.df = df
- self.sequence_col = "" if sequence_col is None else sequence_col
+ # construct schema if we don't already have one
+ if ts_schema:
+ self.ts_schema = ts_schema
+ else:
+ self.ts_schema = TSSchema.fromDFSchema(self.df.schema, ts_col, series_ids)
+ # validate that this schema works for this DataFrame
+ if validate_schema:
+ self.ts_schema.validate(df.schema)
+
+ @property
+ def ts_index(self) -> str:
+ return self.ts_schema.ts_index
+
+ @property
+ def ts_col(self) -> str:
+ if self.ts_schema.user_ts_col:
+ return self.ts_schema.user_ts_col
+ return self.ts_index
+
+ @property
+ def series_ids(self) -> List[str]:
+ return self.ts_schema.series_ids
+
+ @property
+ def structural_cols(self) -> Set[str]:
+ return self.ts_schema.structural_columns
+
+ @property
+ def observational_cols(self) -> List[str]:
+ return [
+ col.name
+ for col in self.ts_schema.find_observational_columns(self.df.schema)
+ ]
- # Add customized check for string type for the timestamp. If we see a string, we will proactively created a double version of the string timestamp for sorting purposes and rename to ts_col
- if df.schema[ts_col].dataType == "StringType":
- sample_ts = df.limit(1).collect()[0][0]
- self.__validate_ts_string(sample_ts)
- self.__add_double_ts().withColumnRenamed("double_ts", self.ts_col)
+ @property
+ def metric_cols(self) -> List[str]:
+ return [col.name for col in self.ts_schema.find_metric_columns(self.df.schema)]
- """
- Make sure DF is ordered by its respective ts_col and partition columns.
- """
+ #
+ # Class Factory Methods
+ #
+
+ @classmethod
+ def __withTransformedDF(cls, new_df: DataFrame, ts_schema: TSSchema) -> "TSDF":
+ return cls(new_df, ts_schema=ts_schema, validate_schema=False)
+
+ # def __init__(self, df, ts_col="event_ts", partition_cols=None, sequence_col=None):
+ # """
+ # Constructor
+ # :param df:
+ # :param ts_col:
+ # :param partitionCols:
+ # :sequence_col every tsdf allows for a tie-breaker secondary sort key
+ # """
+ # self.ts_col = self.__validated_column(df, ts_col)
+ # self.partitionCols = (
+ # []
+ # if partition_cols is None
+ # else self.__validated_columns(df, partition_cols.copy())
+ # )
+ #
+ # self.df = df
+ # self.sequence_col = "" if sequence_col is None else sequence_col
+ #
+ # # Add customized check for string type for the timestamp. If we see a string, we will proactively created a double version of the string timestamp for sorting purposes and rename to ts_col
+ # if df.schema[ts_col].dataType == "StringType":
+ # sample_ts = df.limit(1).collect()[0][0]
+ # self.__validate_ts_string(sample_ts)
+ # self.__add_double_ts().withColumnRenamed("double_ts", self.ts_col)
#
# Helper functions
@@ -119,7 +171,7 @@ def __validated_columns(self, df, colnames):
return colnames
def __checkPartitionCols(self, tsdf_right):
- for left_col, right_col in zip(self.partitionCols, tsdf_right.partitionCols):
+ for left_col, right_col in zip(self.series_ids, tsdf_right.series_ids):
if left_col != right_col:
raise ValueError(
"left and right dataframe partition columns should have same name in same order"
@@ -158,7 +210,7 @@ def __addPrefixToColumns(self, col_list, prefix):
if self.sequence_col
else self.sequence_col
)
- return TSDF(df, ts_col, self.partitionCols, sequence_col=seq_col)
+ return TSDF(df, ts_col, self.series_ids, sequence_col=seq_col)
def __addColumnsFromOtherDF(self, other_cols):
"""
@@ -170,14 +222,14 @@ def __addColumnsFromOtherDF(self, other_cols):
self.df,
)
- return TSDF(new_df, self.ts_col, self.partitionCols)
+ return TSDF(new_df, self.ts_col, self.series_ids)
def __combineTSDF(self, ts_df_right, combined_ts_col):
combined_df = self.df.unionByName(ts_df_right.df).withColumn(
combined_ts_col, f.coalesce(self.ts_col, ts_df_right.ts_col)
)
- return TSDF(combined_df, combined_ts_col, self.partitionCols)
+ return TSDF(combined_df, combined_ts_col, self.series_ids)
def __getLastRightRow(
self,
@@ -197,7 +249,7 @@ def __getLastRightRow(
sort_keys = [f.col(col_name) for col_name in ptntl_sort_keys if col_name != ""]
window_spec = (
- Window.partitionBy(self.partitionCols)
+ Window.partitionBy(self.series_ids)
.orderBy(sort_keys)
.rowsBetween(Window.unboundedPreceding, Window.currentRow)
)
@@ -275,7 +327,7 @@ def __getLastRightRow(
)
df = df.drop(column)
- return TSDF(df, left_ts_col, self.partitionCols)
+ return TSDF(df, left_ts_col, self.series_ids)
def __getTimePartitions(self, tsPartitionVal, fraction=0.1):
"""
@@ -316,7 +368,7 @@ def __getTimePartitions(self, tsPartitionVal, fraction=0.1):
df = partition_df.union(remainder_df).drop(
"partition_remainder", "ts_col_double"
)
- return TSDF(df, self.ts_col, self.partitionCols + ["ts_partition"])
+ return TSDF(df, self.ts_col, self.series_ids + ["ts_partition"])
#
# Slicing & Selection
@@ -342,12 +394,12 @@ def select(self, *cols):
"""
# The columns which will be a mandatory requirement while selecting from TSDFs
seq_col_stub = [] if bool(self.sequence_col) is False else [self.sequence_col]
- mandatory_cols = [self.ts_col] + self.partitionCols + seq_col_stub
+ mandatory_cols = [self.ts_col] + self.series_ids + seq_col_stub
if set(mandatory_cols).issubset(set(cols)):
return TSDF(
self.df.select(*cols),
self.ts_col,
- self.partitionCols,
+ self.series_ids,
self.sequence_col,
)
else:
@@ -369,12 +421,7 @@ def __slice(self, op: str, target_ts):
target_expr = f"'{target_ts}'" if isinstance(target_ts, str) else target_ts
slice_expr = f.expr(f"{self.ts_col} {op} {target_expr}")
sliced_df = self.df.where(slice_expr)
- return TSDF(
- sliced_df,
- ts_col=self.ts_col,
- partition_cols=self.partitionCols,
- sequence_col=self.sequence_col,
- )
+ return TSDF.__withTransformedDF(sliced_df, self.ts_schema)
def at(self, ts):
"""
@@ -456,12 +503,7 @@ def __top_rows_per_series(self, win: WindowSpec, n: int):
.where(f.col(row_num_col) <= f.lit(n))
.drop(row_num_col)
)
- return TSDF(
- prev_records_df,
- ts_col=self.ts_col,
- partition_cols=self.partitionCols,
- sequence_col=self.sequence_col,
- )
+ return TSDF.__withTransformedDF(prev_records_df, self.ts_schema)
def earliest(self, n: int = 1):
"""
@@ -579,7 +621,7 @@ def describe(self):
# describe stats
desc_stats = this_df.describe().union(missing_vals)
- unique_ts = this_df.select(*self.partitionCols).distinct().count()
+ unique_ts = this_df.select(*self.series_ids).distinct().count()
max_ts = this_df.select(f.max(f.col(self.ts_col)).alias("max_ts")).collect()[0][
0
@@ -707,10 +749,10 @@ def asofJoin(
(left_bytes < bytes_threshold) | (right_bytes < bytes_threshold)
):
spark.conf.set("spark.databricks.optimizer.rangeJoin.binSize", 60)
- partition_cols = right_tsdf.partitionCols
- left_cols = list(set(left_df.columns).difference(set(self.partitionCols)))
+ partition_cols = right_tsdf.series_ids
+ left_cols = list(set(left_df.columns).difference(set(self.series_ids)))
right_cols = list(
- set(right_df.columns).difference(set(right_tsdf.partitionCols))
+ set(right_df.columns).difference(set(right_tsdf.series_ids))
)
left_prefix = (
@@ -753,7 +795,7 @@ def asofJoin(
)
.drop("lead_" + right_tsdf.ts_col)
)
- return TSDF(res, partition_cols=self.partitionCols, ts_col=new_left_ts_col)
+ return TSDF(res, series_ids=self.series_ids, ts_col=new_left_ts_col)
# end of block checking to see if standard Spark SQL join will work
@@ -772,11 +814,9 @@ def asofJoin(
# validate timestamp datatypes match
self.__validateTsColMatch(right_tsdf)
- orig_left_col_diff = list(
- set(left_df.columns).difference(set(self.partitionCols))
- )
+ orig_left_col_diff = list(set(left_df.columns).difference(set(self.series_ids)))
orig_right_col_diff = list(
- set(right_df.columns).difference(set(self.partitionCols))
+ set(right_df.columns).difference(set(self.series_ids))
)
left_tsdf = (
@@ -789,10 +829,10 @@ def asofJoin(
)
left_nonpartition_cols = list(
- set(left_tsdf.df.columns).difference(set(self.partitionCols))
+ set(left_tsdf.df.columns).difference(set(self.series_ids))
)
right_nonpartition_cols = list(
- set(right_tsdf.df.columns).difference(set(self.partitionCols))
+ set(right_tsdf.df.columns).difference(set(self.series_ids))
)
# For both dataframes get all non-partition columns (including ts_col)
@@ -836,29 +876,24 @@ def asofJoin(
"ts_partition", "is_original"
)
- asofDF = TSDF(df, asofDF.ts_col, combined_df.partitionCols)
+ asofDF = TSDF(df, asofDF.ts_col, combined_df.series_ids)
return asofDF
def __baseWindow(self, sort_col=None, reverse=False):
- # figure out our sorting columns
- primary_sort_col = self.ts_col if not sort_col else sort_col
- sort_cols = (
- [primary_sort_col, self.sequence_col]
- if self.sequence_col
- else [primary_sort_col]
- )
-
# are we ordering forwards (default) or reversed?
col_fn = f.col
if reverse:
col_fn = lambda colname: f.col(colname).desc() # noqa E731
# our window will be sorted on our sort_cols in the appropriate direction
- w = Window().orderBy([col_fn(col) for col in sort_cols])
+ if reverse:
+ w = Window().orderBy(f.col(self.ts_index).desc())
+ else:
+ w = Window().orderBy(f.col(self.ts_index).asc())
# and partitioned by any series IDs
- if self.partitionCols:
- w = w.partitionBy([f.col(elem) for elem in self.partitionCols])
+ if self.series_ids:
+ w = w.partitionBy([f.col(sid) for sid in self.series_ids])
return w
def __rangeBetweenWindow(self, range_from, range_to, sort_col=None, reverse=False):
@@ -900,8 +935,8 @@ def vwap(self, frequency="m", volume_col="volume", price_col="price"):
)
group_cols = ["time_group"]
- if self.partitionCols:
- group_cols.extend(self.partitionCols)
+ if self.series_ids:
+ group_cols.extend(self.series_ids)
vwapped = (
pre_vwap.withColumn("dllr_value", f.col(price_col) * f.col(volume_col))
.groupby(group_cols)
@@ -913,7 +948,7 @@ def vwap(self, frequency="m", volume_col="volume", price_col="price"):
.withColumn("vwap", f.col("dllr_value") / f.col(volume_col))
)
- return TSDF(vwapped, self.ts_col, self.partitionCols)
+ return TSDF(vwapped, self.ts_col, self.series_ids)
def EMA(self, colName, window=30, exp_factor=0.2):
"""
@@ -940,7 +975,7 @@ def EMA(self, colName, window=30, exp_factor=0.2):
).drop(lagColName)
# Nulls are currently removed
- return TSDF(df, self.ts_col, self.partitionCols)
+ return TSDF(df, self.ts_col, self.series_ids)
def withLookbackFeatures(
self, featureCols, lookbackWindowSize, exactSize=True, featureColName="features"
@@ -974,7 +1009,7 @@ def withLookbackFeatures(
if exactSize:
return lookback_tsdf.where(f.size(featureColName) == lookbackWindowSize)
- return TSDF(lookback_tsdf, self.ts_col, self.partitionCols)
+ return TSDF(lookback_tsdf, self.ts_col, self.series_ids)
def withRangeStats(
self, type="range", colsToSummarize=[], rangeBackWindowSecs=1000
@@ -1000,8 +1035,8 @@ def withRangeStats(
if not colsToSummarize:
# columns we should never summarize
prohibited_cols = [self.ts_col.lower()]
- if self.partitionCols:
- prohibited_cols.extend([pc.lower() for pc in self.partitionCols])
+ if self.series_ids:
+ prohibited_cols.extend([pc.lower() for pc in self.series_ids])
# types that can be summarized
summarizable_types = ["int", "bigint", "float", "double"]
# filter columns to find summarizable columns
@@ -1045,7 +1080,7 @@ def withRangeStats(
"double_ts"
)
- return TSDF(summary_df, self.ts_col, self.partitionCols)
+ return TSDF(summary_df, self.ts_col, self.series_ids)
def withGroupedStats(self, metricCols=[], freq=None):
"""
@@ -1062,8 +1097,8 @@ def withGroupedStats(self, metricCols=[], freq=None):
if not metricCols:
# columns we should never summarize
prohibited_cols = [self.ts_col.lower()]
- if self.partitionCols:
- prohibited_cols.extend([pc.lower() for pc in self.partitionCols])
+ if self.series_ids:
+ prohibited_cols.extend([pc.lower() for pc in self.series_ids])
# types that can be summarized
summarizable_types = ["int", "bigint", "float", "double"]
# filter columns to find summarizable columns
@@ -1097,16 +1132,14 @@ def withGroupedStats(self, metricCols=[], freq=None):
]
)
- selected_df = self.df.groupBy(self.partitionCols + [agg_window]).agg(
- *selectedCols
- )
+ selected_df = self.df.groupBy(self.series_ids + [agg_window]).agg(*selectedCols)
summary_df = (
selected_df.select(*selected_df.columns)
.withColumn(self.ts_col, f.col("window").start)
.drop("window")
)
- return TSDF(summary_df, self.ts_col, self.partitionCols)
+ return TSDF(summary_df, self.ts_col, self.series_ids)
def write(self, spark, tabName, optimizationCols=None):
tio.write(self, spark, tabName, optimizationCols)
@@ -1134,7 +1167,7 @@ def resample(
# Throw warning for user to validate that the expected number of output rows is valid.
if fill is True and perform_checks is True:
- calculate_time_horizon(self.df, self.ts_col, freq, self.partitionCols)
+ calculate_time_horizon(self.df, self.ts_col, freq, self.series_ids)
enriched_df: DataFrame = rs.aggregate(
self, freq, func, metricCols, prefix, fill
@@ -1142,7 +1175,7 @@ def resample(
return _ResampledTSDF(
enriched_df,
ts_col=self.ts_col,
- partition_cols=self.partitionCols,
+ series_ids=self.series_ids,
freq=freq,
func=func,
)
@@ -1177,7 +1210,7 @@ def interpolate(
if ts_col is None:
ts_col = self.ts_col
if partition_cols is None:
- partition_cols = self.partitionCols
+ partition_cols = self.series_ids
if target_cols is None:
prohibited_cols: List[str] = partition_cols + [ts_col]
summarizable_types = ["int", "bigint", "float", "double"]
@@ -1193,7 +1226,7 @@ def interpolate(
]
interpolate_service: Interpolation = Interpolation(is_resampled=False)
- tsdf_input = TSDF(self.df, ts_col=ts_col, partition_cols=partition_cols)
+ tsdf_input = TSDF(self.df, ts_col=ts_col, series_ids=partition_cols)
interpolated_df: DataFrame = interpolate_service.interpolate(
tsdf_input,
ts_col,
@@ -1206,7 +1239,7 @@ def interpolate(
perform_checks,
)
- return TSDF(interpolated_df, ts_col=ts_col, partition_cols=partition_cols)
+ return TSDF(interpolated_df, ts_col=ts_col, series_ids=partition_cols)
def calc_bars(tsdf, freq, func=None, metricCols=None, fill=None):
@@ -1223,21 +1256,21 @@ def calc_bars(tsdf, freq, func=None, metricCols=None, fill=None):
freq=freq, func="ceil", metricCols=metricCols, prefix="close", fill=fill
)
- join_cols = resample_open.partitionCols + [resample_open.ts_col]
+ join_cols = resample_open.series_ids + [resample_open.ts_col]
bars = (
resample_open.df.join(resample_high.df, join_cols)
.join(resample_low.df, join_cols)
.join(resample_close.df, join_cols)
)
- non_part_cols = set(set(bars.columns) - set(resample_open.partitionCols)) - set(
+ non_part_cols = set(set(bars.columns) - set(resample_open.series_ids)) - set(
[resample_open.ts_col]
)
sel_and_sort = (
- resample_open.partitionCols + [resample_open.ts_col] + sorted(non_part_cols)
+ resample_open.series_ids + [resample_open.ts_col] + sorted(non_part_cols)
)
bars = bars.select(sel_and_sort)
- return TSDF(bars, resample_open.ts_col, resample_open.partitionCols)
+ return TSDF(bars, resample_open.ts_col, resample_open.series_ids)
def fourier_transform(self, timestep, valueCol):
"""
@@ -1267,7 +1300,7 @@ def tempo_fourier_util(pdf):
valueCol = self.__validated_column(self.df, valueCol)
data = self.df
if self.sequence_col:
- if self.partitionCols == []:
+ if self.series_ids == []:
data = data.withColumn("dummy_group", f.lit("dummy_val"))
data = (
data.select(
@@ -1288,7 +1321,7 @@ def tempo_fourier_util(pdf):
)
result = result.drop("dummy_group", "tdval", "tpoints")
else:
- group_cols = self.partitionCols
+ group_cols = self.series_ids
data = (
data.select(
*group_cols, self.ts_col, self.sequence_col, f.col(valueCol)
@@ -1305,7 +1338,7 @@ def tempo_fourier_util(pdf):
)
result = result.drop("tdval", "tpoints")
else:
- if self.partitionCols == []:
+ if self.series_ids == []:
data = data.withColumn("dummy_group", f.lit("dummy_val"))
data = (
data.select(f.col("dummy_group"), self.ts_col, f.col(valueCol))
@@ -1321,7 +1354,7 @@ def tempo_fourier_util(pdf):
)
result = result.drop("dummy_group", "tdval", "tpoints")
else:
- group_cols = self.partitionCols
+ group_cols = self.series_ids
data = (
data.select(*group_cols, self.ts_col, f.col(valueCol))
.withColumn("tdval", f.col(valueCol))
@@ -1336,7 +1369,7 @@ def tempo_fourier_util(pdf):
)
result = result.drop("tdval", "tpoints")
- return TSDF(result, self.ts_col, self.partitionCols, self.sequence_col)
+ return TSDF(result, self.ts_col, self.series_ids, self.sequence_col)
def extractStateIntervals(
self,
@@ -1447,7 +1480,7 @@ def state_comparison_fn(a, b):
# Find the start and end timestamp of the interval
result = (
- data.groupBy(*self.partitionCols, "state_incrementer")
+ data.groupBy(*self.series_ids, "state_incrementer")
.agg(
f.min("previous_ts").alias("start_ts"),
f.max(self.ts_col).alias("end_ts"),
@@ -1463,12 +1496,12 @@ def __init__(
self,
df,
ts_col="event_ts",
- partition_cols=None,
+ series_ids=None,
sequence_col=None,
freq=None,
func=None,
):
- super(_ResampledTSDF, self).__init__(df, ts_col, partition_cols, sequence_col)
+ super(_ResampledTSDF, self).__init__(df, ts_col, series_ids, sequence_col)
self.__freq = freq
self.__func = func
@@ -1491,7 +1524,7 @@ def interpolate(
# Set defaults for target columns, timestamp column and partition columns when not provided
if target_cols is None:
- prohibited_cols: List[str] = self.partitionCols + [self.ts_col]
+ prohibited_cols: List[str] = self.series_ids + [self.ts_col]
summarizable_types = ["int", "bigint", "float", "double"]
# find summarizable columns
@@ -1505,13 +1538,11 @@ def interpolate(
]
interpolate_service: Interpolation = Interpolation(is_resampled=True)
- tsdf_input = TSDF(
- self.df, ts_col=self.ts_col, partition_cols=self.partitionCols
- )
+ tsdf_input = TSDF(self.df, ts_col=self.ts_col, series_ids=self.series_ids)
interpolated_df = interpolate_service.interpolate(
tsdf=tsdf_input,
ts_col=self.ts_col,
- partition_cols=self.partitionCols,
+ series_ids=self.series_ids,
target_cols=target_cols,
freq=self.__freq,
func=self.__func,
@@ -1520,6 +1551,4 @@ def interpolate(
perform_checks=perform_checks,
)
- return TSDF(
- interpolated_df, ts_col=self.ts_col, partition_cols=self.partitionCols
- )
+ return TSDF(interpolated_df, ts_col=self.ts_col, series_ids=self.series_ids)
diff --git a/python/tempo/tsschema.py b/python/tempo/tsschema.py
new file mode 100644
index 00000000..457e0884
--- /dev/null
+++ b/python/tempo/tsschema.py
@@ -0,0 +1,99 @@
+from typing import Collection
+
+from pyspark.sql.types import *
+
+
+class TSIndex:
+ # Valid types for time index columns
+ __valid_ts_types = (
+ DateType(),
+ TimestampType(),
+ ByteType(),
+ ShortType(),
+ IntegerType(),
+ LongType(),
+ DecimalType(),
+ FloatType(),
+ DoubleType(),
+ )
+
+ def __init__(self, name: str, dataType: DataType) -> None:
+ if dataType not in self.__valid_ts_types:
+ raise TypeError(f"DataType {dataType} is not valid for a Timeseries Index")
+ self.name = name
+ self.dataType = dataType
+
+ @classmethod
+ def fromField(cls, ts_field: StructField) -> "TSIndex":
+ return cls(ts_field.name, ts_field.dataType)
+
+
+class TSSchema:
+ """
+ Schema type for a :class:`TSDF` class.
+ """
+
+ # Valid types for metric columns
+ __metric_types = (
+ BooleanType(),
+ ByteType(),
+ ShortType(),
+ IntegerType(),
+ LongType(),
+ DecimalType(),
+ FloatType(),
+ DoubleType(),
+ )
+
+ def __init__(
+ self,
+ ts_idx: TSIndex,
+ series_ids: Collection[str] = None,
+ user_ts_col: str = None,
+ subsequence_col: str = None,
+ ) -> None:
+ self.ts_idx = ts_idx
+ self.series_ids = list(series_ids)
+ self.user_ts_col = user_ts_col
+ self.subsequence_col = subsequence_col
+
+ @classmethod
+ def fromDFSchema(
+ cls, df_schema: StructType, ts_col: str, series_ids: Collection[str] = None
+ ) -> "TSSchema":
+ # construct a TSIndex for the given ts_col
+ ts_idx = TSIndex.fromField(df_schema[ts_col])
+ return cls(ts_idx, series_ids)
+
+ @property
+ def ts_index(self) -> str:
+ return self.ts_idx.name
+
+ @property
+ def structural_columns(self) -> set[str]:
+ """
+ Structural columns are those that define the structure of the :class:`TSDF`. This includes the timeseries column,
+ a timeseries index (if different), any subsequence column (if present), and the series ID columns.
+
+    :return: a set of column names corresponding to the structural columns of a :class:`TSDF`
+ """
+ struct_cols = {self.ts_index, self.user_ts_col, self.subsequence_col}.union(
+ self.series_ids
+ )
+ struct_cols.discard(None)
+ return struct_cols
+
+ def validate(self, df_schema: StructType) -> None:
+ pass
+
+ def find_observational_columns(self, df_schema: StructType) -> list[StructField]:
+ return [
+ col for col in df_schema.fields if col.name not in self.structural_columns
+ ]
+
+ def find_metric_columns(self, df_schema: StructType) -> list[StructField]:
+ return [
+ col
+ for col in self.find_observational_columns(df_schema)
+ if col.dataType in self.__metric_types
+ ]
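The new module can also be exercised on its own against a DataFrame schema. A minimal sketch of the patch-01 API, using an illustrative schema (none of these column names come from the patch):

    from pyspark.sql.types import (
        StructType, StructField, StringType, TimestampType, DoubleType,
    )
    from tempo.tsschema import TSSchema

    schema = StructType([
        StructField("symbol", StringType()),
        StructField("event_ts", TimestampType()),
        StructField("trade_pr", DoubleType()),
    ])

    ts_schema = TSSchema.fromDFSchema(schema, ts_col="event_ts", series_ids=["symbol"])
    ts_schema.ts_index                                        # "event_ts"
    ts_schema.structural_columns                              # {"event_ts", "symbol"}
    [f.name for f in ts_schema.find_metric_columns(schema)]   # ["trade_pr"]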
From 3c3e5f85e18f2887e794fd72b4aa58b8678ea418 Mon Sep 17 00:00:00 2001
From: Tristan Nixon
Date: Mon, 29 Aug 2022 10:24:46 -0700
Subject: [PATCH 02/11] saving progress to this point
---
python/requirements.txt | 1 -
python/tempo/io.py | 4 +-
python/tempo/tsdf.py | 42 ++--
python/tempo/tsschema.py | 110 ++++++--
python/tests/base.py | 7 +-
python/tests/tsdf_tests.py | 2 +-
.../unit_test_data/as_of_join_tests.json | 32 +--
.../unit_test_data/delta_writer_tests.json | 2 +-
.../tests/unit_test_data/interpol_tests.json | 26 +-
python/tests/unit_test_data/tsdf_tests.json | 236 +++++-------------
python/tests/unit_test_data/utils_tests.json | 2 +-
11 files changed, 208 insertions(+), 256 deletions(-)
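Besides renaming partition_cols to series_ids across io.py, the tests, and the fixture files, this patch stubs out three alternate constructors on TSDF (fromSubsequenceCol, fromTimestampString, fromDateString) in the tsdf.py hunk below. A hedged sketch of the intended call patterns, taken only from the stub signatures (the bodies are still pass, so these calls do not do anything yet; the DataFrames and column names are illustrative):

    # tie-breaking sub-sequence column within a timestamp
    tsdf = TSDF.fromSubsequenceCol(trades_df, ts_col="event_ts",
                                   subsequence_col="seq_nb", series_ids=["symbol"])

    # string-typed timestamp parsed with an explicit format
    tsdf = TSDF.fromTimestampString(quotes_df, ts_col="event_ts",
                                    series_ids=["symbol"],
                                    ts_fmt="YYYY-MM-DDThh:mm:ss[.SSSSSS]")

    # string-typed date parsed with an explicit format
    tsdf = TSDF.fromDateString(dates_df, ts_col="date",
                               series_ids=["symbol"], date_fmt="YYYY-MM-DD")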
diff --git a/python/requirements.txt b/python/requirements.txt
index 0c61dab8..0a850c60 100644
--- a/python/requirements.txt
+++ b/python/requirements.txt
@@ -19,4 +19,3 @@ Sphinx==4.5.0
sphinx-design==0.2.0
sphinx-panels==0.6.0
jsonref==0.2
-python-dateutil==2.8.2
diff --git a/python/tempo/io.py b/python/tempo/io.py
index dedbc024..e0cc7f53 100644
--- a/python/tempo/io.py
+++ b/python/tempo/io.py
@@ -18,7 +18,7 @@ def write(tsdf, spark, tabName, optimizationCols=None):
df = tsdf.df
ts_col = tsdf.ts_col
- partitionCols = tsdf.partitionCols
+ series_ids = tsdf.series_ids
if optimizationCols:
optimizationCols = optimizationCols + ["event_time"]
else:
@@ -44,7 +44,7 @@ def write(tsdf, spark, tabName, optimizationCols=None):
try:
spark.sql(
"optimize {} zorder by {}".format(
- tabName, "(" + ",".join(partitionCols + optimizationCols) + ")"
+ tabName, "(" + ",".join(series_ids + optimizationCols) + ")"
)
)
except Exception as e:
diff --git a/python/tempo/tsdf.py b/python/tempo/tsdf.py
index b5c50d2a..f124b031 100644
--- a/python/tempo/tsdf.py
+++ b/python/tempo/tsdf.py
@@ -52,15 +52,28 @@ def __init__(
if validate_schema:
self.ts_schema.validate(df.schema)
+ def __withTransformedDF(self, new_df: DataFrame) -> "TSDF":
+ return TSDF(new_df, ts_schema=self.ts_schema, validate_schema=False)
+
+ @classmethod
+ def fromSubsequenceCol(cls, df: DataFrame, ts_col: str, subsequence_col: str, series_ids: Collection[str] = None) -> "TSDF":
+ pass
+
+ @classmethod
+ def fromTimestampString(cls, df: DataFrame, ts_col: str, series_ids: Collection[str] = None, ts_fmt: str = "YYYY-MM-DDThh:mm:ss[.SSSSSS]") -> "TSDF":
+ pass
+
+ @classmethod
+    def fromDateString(cls, df: DataFrame, ts_col: str, series_ids: Collection[str], date_fmt: str = "YYYY-MM-DD") -> "TSDF":
+ pass
+
@property
def ts_index(self) -> str:
return self.ts_schema.ts_index
@property
def ts_col(self) -> str:
- if self.ts_schema.user_ts_col:
- return self.ts_schema.user_ts_col
- return self.ts_index
+ return self.ts_index.name
@property
def series_ids(self) -> List[str]:
@@ -81,14 +94,6 @@ def observational_cols(self) -> List[str]:
def metric_cols(self) -> List[str]:
return [col.name for col in self.ts_schema.find_metric_columns(self.df.schema)]
- #
- # Class Factory Methods
- #
-
- @classmethod
- def __withTransformedDF(cls, new_df: DataFrame, ts_schema: TSSchema) -> "TSDF":
- return cls(new_df, ts_schema=ts_schema, validate_schema=False)
-
# def __init__(self, df, ts_col="event_ts", partition_cols=None, sequence_col=None):
# """
# Constructor
@@ -421,7 +426,7 @@ def __slice(self, op: str, target_ts):
target_expr = f"'{target_ts}'" if isinstance(target_ts, str) else target_ts
slice_expr = f.expr(f"{self.ts_col} {op} {target_expr}")
sliced_df = self.df.where(slice_expr)
- return TSDF.__withTransformedDF(sliced_df, self.ts_schema)
+ return self.__withTransformedDF(sliced_df)
def at(self, ts):
"""
@@ -503,7 +508,7 @@ def __top_rows_per_series(self, win: WindowSpec, n: int):
.where(f.col(row_num_col) <= f.lit(n))
.drop(row_num_col)
)
- return TSDF.__withTransformedDF(prev_records_df, self.ts_schema)
+ return self.__withTransformedDF(prev_records_df)
def earliest(self, n: int = 1):
"""
@@ -904,15 +909,6 @@ def __rangeBetweenWindow(self, range_from, range_to, sort_col=None, reverse=Fals
def __rowsBetweenWindow(self, rows_from, rows_to, reverse=False):
return self.__baseWindow(reverse=reverse).rowsBetween(rows_from, rows_to)
- def withPartitionCols(self, partitionCols):
- """
- Sets certain columns of the TSDF as partition columns. Partition columns are those that differentiate distinct timeseries
- from each other.
- :param partitionCols: a list of columns used to partition distinct timeseries
- :return: a TSDF object with the given partition columns
- """
- return TSDF(self.df, self.ts_col, partitionCols)
-
def vwap(self, frequency="m", volume_col="volume", price_col="price"):
# set pre_vwap as self or enrich with the frequency
pre_vwap = self.df
@@ -1051,7 +1047,7 @@ def withRangeStats(
# build window
if str(self.df.schema[self.ts_col].dataType) == "TimestampType":
- self.__add_double_ts()
+ self. __add_double_ts()
prohibited_cols.extend(["double_ts"])
w = self.__rangeBetweenWindow(
-1 * rangeBackWindowSecs, 0, sort_col="double_ts"
diff --git a/python/tempo/tsschema.py b/python/tempo/tsschema.py
index 457e0884..4e425820 100644
--- a/python/tempo/tsschema.py
+++ b/python/tempo/tsschema.py
@@ -1,9 +1,16 @@
-from typing import Collection
+from abc import ABC, abstractmethod
+from typing import Union, Collection, List
+from pyspark.sql import Column
+import pyspark.sql.functions as Fn
from pyspark.sql.types import *
-class TSIndex:
+class TSIndex(ABC):
+ """
+ Abstract base class for all Timeseries Index types
+ """
+
# Valid types for time index columns
__valid_ts_types = (
DateType(),
@@ -18,14 +25,70 @@ class TSIndex:
)
def __init__(self, name: str, dataType: DataType) -> None:
- if dataType not in self.__valid_ts_types:
- raise TypeError(f"DataType {dataType} is not valid for a Timeseries Index")
self.name = name
self.dataType = dataType
- @classmethod
- def fromField(cls, ts_field: StructField) -> "TSIndex":
- return cls(ts_field.name, ts_field.dataType)
+ @abstractmethod
+ def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]:
+ """
+ Returns a :class:`Column` expression that will order the :class:`TSDF` according to the timeseries index.
+
+ :param reverse: whether or not the ordering should be reversed (backwards in time)
+
+ :return: an expression appropriate for ordering the :class:`TSDF` according to this index
+ """
+ pass
+
+class SimpleTSIndex(TSIndex):
+ """
+ Timeseries index based on a single column of a numeric or temporal type.
+ """
+
+ def __init__(self, ts_col: StructField) -> None:
+ if ts_col.dataType not in self.__valid_ts_types:
+ raise TypeError(f"DataType {ts_col.dataType} of column {ts_col.name} is not valid for a timeseries Index")
+ super().__init__(ts_col.name, ts_col.dataType)
+
+ def orderByExpr(self, reverse: bool = False) -> Column:
+ expr = Fn.col(self.name)
+ if reverse:
+ return expr.desc()
+ return expr
+
+
+class SubSequenceTSIndex(TSIndex):
+ """
+ Special timeseries index for columns that involve a secondary sequencing column
+ """
+
+ # default name for our timeseries index
+ __ts_idx_name = "ts_index"
+ # Valid types for sub-sequence columns
+ __valid_subseq_types = (
+ ByteType(),
+ ShortType(),
+ IntegerType(),
+ LongType()
+ )
+
+ def __init__(self, primary_ts_col: StructField, subsequence_col: StructField) -> None:
+ # validate these column types
+ if primary_ts_col.dataType not in self.__valid_ts_types:
+ raise TypeError(f"DataType {primary_ts_col.dataType} of column {primary_ts_col.name} is not valid for a timeseries Index")
+ if subsequence_col.dataType not in self.__valid_subseq_types:
+ raise TypeError(f"DataType {subsequence_col.dataType} of column {subsequence_col.name} is not valid for a sub-sequencing column")
+ # construct a struct for these
+ ts_struct = StructType([primary_ts_col, subsequence_col])
+ super().__init__(self.__ts_idx_name, ts_struct)
+ # set colnames for primary & subsequence
+ self.primary_ts_col = primary_ts_col.name
+ self.subsequence_col = subsequence_col.name
+
+ def orderByExpr(self, reverse: bool = False) -> List[Column]:
+ expr = [ Fn.col(self.primary_ts_col), Fn.col(self.subsequence_col) ]
+ if reverse:
+ return [col.desc() for col in expr]
+ return expr
class TSSchema:
@@ -48,21 +111,20 @@ class TSSchema:
def __init__(
self,
ts_idx: TSIndex,
- series_ids: Collection[str] = None,
- user_ts_col: str = None,
- subsequence_col: str = None,
+ series_ids: Collection[str] = None
) -> None:
self.ts_idx = ts_idx
- self.series_ids = list(series_ids)
- self.user_ts_col = user_ts_col
- self.subsequence_col = subsequence_col
+ if series_ids:
+ self.series_ids = list(series_ids)
+ else:
+ self.series_ids = None
@classmethod
def fromDFSchema(
cls, df_schema: StructType, ts_col: str, series_ids: Collection[str] = None
) -> "TSSchema":
# construct a TSIndex for the given ts_col
- ts_idx = TSIndex.fromField(df_schema[ts_col])
+ ts_idx = SimpleTSIndex(df_schema[ts_col])
return cls(ts_idx, series_ids)
@property
@@ -77,23 +139,21 @@ def structural_columns(self) -> set[str]:
:return: a set of column names corresponding to the structural columns of a :class:`TSDF`
"""
- struct_cols = {self.ts_index, self.user_ts_col, self.subsequence_col}.union(
- self.series_ids
- )
+ struct_cols = {self.ts_index}.union(self.series_ids)
struct_cols.discard(None)
return struct_cols
def validate(self, df_schema: StructType) -> None:
pass
- def find_observational_columns(self, df_schema: StructType) -> list[StructField]:
- return [
- col for col in df_schema.fields if col.name not in self.structural_columns
- ]
+ def find_observational_columns(self, df_schema: StructType) -> set[str]:
+ return set(df_schema.fieldNames()) - self.structural_columns
- def find_metric_columns(self, df_schema: StructType) -> list[StructField]:
+ def find_metric_columns(self, df_schema: StructType) -> list[str]:
return [
- col
- for col in self.find_observational_columns(df_schema)
- if col.dataType in self.__metric_types
+ col.name
+ for col in df_schema.fields
+ if (col.dataType in self.__metric_types)
+ and
+ (col.name in self.find_observational_columns(df_schema))
]
diff --git a/python/tests/base.py b/python/tests/base.py
index eef48f59..550515c6 100644
--- a/python/tests/base.py
+++ b/python/tests/base.py
@@ -71,12 +71,7 @@ def get_data_as_sdf(self, name: str, convert_ts_col=True):
def get_data_as_tsdf(self, name: str, convert_ts_col=True):
df = self.get_data_as_sdf(name, convert_ts_col)
td = self.test_data[name]
- tsdf = TSDF(
- df,
- ts_col=td["ts_col"],
- partition_cols=td.get("partition_cols", None),
- sequence_col=td.get("sequence_col", None),
- )
+ tsdf = TSDF(df, ts_col=td["ts_col"], series_ids=td.get("series_ids", None))
return tsdf
TEST_DATA_FOLDER = "unit_test_data"
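With the harness above, test fixtures now carry a series_ids key in place of partition_cols. A sketch of one fixture entry expressed as a Python dict, mirroring the JSON files that follow (the values are illustrative), and how get_data_as_tsdf consumes it, assuming an active SparkSession named spark:

    from tempo.tsdf import TSDF

    td = {
        "schema": "symbol string, event_ts string, trade_pr float",
        "ts_col": "event_ts",
        "series_ids": ["symbol"],  # key renamed from "partition_cols"
        "data": [["S1", "2020-08-01 00:00:10", 349.21]],
    }

    df = spark.createDataFrame(td["data"], td["schema"])
    tsdf = TSDF(df, ts_col=td["ts_col"], series_ids=td.get("series_ids", None))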
diff --git a/python/tests/tsdf_tests.py b/python/tests/tsdf_tests.py
index 5fae468f..e1326121 100644
--- a/python/tests/tsdf_tests.py
+++ b/python/tests/tsdf_tests.py
@@ -49,7 +49,7 @@ def __tsdf_with_double_tscol(self, tsdf: TSDF) -> TSDF:
with_double_tscol_df = tsdf.df.withColumn(
tsdf.ts_col, F.col(tsdf.ts_col).cast("double")
)
- return TSDF(with_double_tscol_df, tsdf.ts_col, tsdf.partitionCols)
+ return TSDF(with_double_tscol_df, ts_col=tsdf.ts_col, series_ids=tsdf.series_ids)
def test_at(self):
"""
diff --git a/python/tests/unit_test_data/as_of_join_tests.json b/python/tests/unit_test_data/as_of_join_tests.json
index d788d60a..879ed220 100644
--- a/python/tests/unit_test_data/as_of_join_tests.json
+++ b/python/tests/unit_test_data/as_of_join_tests.json
@@ -3,7 +3,7 @@
"shared_left": {
"schema": "symbol string, event_ts string, trade_pr float",
"ts_col": "event_ts",
- "partition_cols": ["symbol"],
+ "series_ids": ["symbol"],
"data": [
["S1", "2020-08-01 00:00:10", 349.21],
["S1", "2020-08-01 00:01:12", 351.32],
@@ -26,7 +26,7 @@
"right": {
"schema": "symbol string, event_ts string, bid_pr float, ask_pr float",
"ts_col": "event_ts",
- "partition_cols": ["symbol"],
+ "series_ids": ["symbol"],
"data": [
["S1", "2020-08-01 00:00:01", 345.11, 351.12],
["S1", "2020-08-01 00:01:05", 348.10, 353.13],
@@ -37,7 +37,7 @@
"expected": {
"schema": "symbol string, left_event_ts string, left_trade_pr float, right_event_ts string, right_bid_pr float, right_ask_pr float",
"ts_col": "left_event_ts",
- "partition_cols": ["symbol"],
+ "series_ids": ["symbol"],
"other_ts_cols": ["right_event_ts"],
"data": {
"$ref": "#/__SharedData/test_asof_expected_data"
@@ -46,7 +46,7 @@
"expected_no_right_prefix": {
"schema": "symbol string, left_event_ts string, left_trade_pr float, event_ts string, bid_pr float, ask_pr float",
"ts_col": "left_event_ts",
- "partition_cols": ["symbol"],
+ "series_ids": ["symbol"],
"other_ts_cols": ["event_ts"],
"data": {
"$ref": "#/__SharedData/test_asof_expected_data"
@@ -60,7 +60,7 @@
"right": {
"schema": "symbol string, event_ts string, bid_pr float, ask_pr float",
"ts_col": "event_ts",
- "partition_cols": ["symbol"],
+ "series_ids": ["symbol"],
"data": [
["S1", "2020-08-01 00:00:01", 345.11, 351.12],
["S1", "2020-08-01 00:01:05", null, 353.13],
@@ -71,7 +71,7 @@
"expected_skip_nulls": {
"schema": "symbol string, left_event_ts string, left_trade_pr float, right_event_ts string, right_bid_pr float, right_ask_pr float",
"ts_col": "left_event_ts",
- "partition_cols": ["symbol"],
+ "series_ids": ["symbol"],
"other_ts_cols": ["right_event_ts"],
"data": [
["S1", "2020-08-01 00:00:10", 349.21, "2020-08-01 00:00:01", 345.11, 351.12],
@@ -83,7 +83,7 @@
"expected_skip_nulls_disabled": {
"schema": "symbol string, left_event_ts string, left_trade_pr float, right_event_ts string, right_bid_pr float, right_ask_pr float",
"ts_col": "left_event_ts",
- "partition_cols": ["symbol"],
+ "series_ids": ["symbol"],
"other_ts_cols": ["right_event_ts"],
"data": [
["S1", "2020-08-01 00:00:10", 349.21, "2020-08-01 00:00:01", 345.11, 351.12],
@@ -97,7 +97,7 @@
"left": {
"schema": "symbol string, event_ts string, trade_pr float, trade_id int",
"ts_col": "event_ts",
- "partition_cols": ["symbol"],
+ "series_ids": ["symbol"],
"data": [
["S1", "2020-08-01 00:00:10", 349.21, 1],
["S1", "2020-08-01 00:00:10", 350.21, 5],
@@ -109,7 +109,7 @@
"right": {
"schema": "symbol string, event_ts string, bid_pr float, ask_pr float, seq_nb long",
"ts_col": "event_ts",
- "partition_cols": ["symbol"],
+ "series_ids": ["symbol"],
"sequence_col": "seq_nb",
"data": [
["S1", "2020-08-01 00:00:01", 345.11, 351.12, 1],
@@ -123,7 +123,7 @@
"expected": {
"schema": "symbol string, event_ts string, trade_pr float, trade_id int, right_event_ts string, right_bid_pr float, right_ask_pr float, right_seq_nb long",
"ts_col": "event_ts",
- "partition_cols": ["symbol"],
+ "series_ids": ["symbol"],
"other_ts_cols": ["right_event_ts"],
"data": [
["S1", "2020-08-01 00:00:10", 349.21, 1, "2020-08-01 00:00:10", 19.11, 20.12, 1],
@@ -138,7 +138,7 @@
"left": {
"schema": "symbol string, event_ts string, trade_pr float",
"ts_col": "event_ts",
- "partition_cols": ["symbol"],
+ "series_ids": ["symbol"],
"data": [
["S1", "2020-08-01 00:00:02", 349.21],
["S1", "2020-08-01 00:00:08", 351.32],
@@ -152,7 +152,7 @@
"right": {
"schema": "symbol string, event_ts string, bid_pr float, ask_pr float",
"ts_col": "event_ts",
- "partition_cols": ["symbol"],
+ "series_ids": ["symbol"],
"data": [
["S1", "2020-08-01 00:00:01", 345.11, 351.12],
["S1", "2020-08-01 00:00:09", 348.10, 353.13],
@@ -163,7 +163,7 @@
"expected": {
"schema": "symbol string, left_event_ts string, left_trade_pr float, right_event_ts string, right_bid_pr float, right_ask_pr float",
"ts_col": "left_event_ts",
- "partition_cols": ["symbol"],
+ "series_ids": ["symbol"],
"other_ts_cols": ["right_event_ts"],
"data": [
["S1", "2020-08-01 00:00:02", 349.21, "2020-08-01 00:00:01", 345.11, 351.12],
@@ -180,7 +180,7 @@
"left": {
"schema": "symbol string, event_ts string, trade_pr float",
"ts_col": "event_ts",
- "partition_cols": ["symbol"],
+ "series_ids": ["symbol"],
"data": [
["S1", "2022-01-01 09:59:59.123456789", 349.21],
["S1", "2022-01-01 10:00:00.123456788", 351.32],
@@ -191,7 +191,7 @@
"right": {
"schema": "symbol string, event_ts string, bid_pr float, ask_pr float",
"ts_col": "event_ts",
- "partition_cols": ["symbol"],
+ "series_ids": ["symbol"],
"data": [
["S1", "2022-01-01 10:00:00.1234567", 345.11, 351.12],
["S1", "2022-01-01 10:00:00.12345671", 348.10, 353.13],
@@ -203,7 +203,7 @@
"expected": {
"schema": "symbol string, left_event_ts string, left_trade_pr float, right_event_ts string, right_ask_pr float, right_bid_pr float",
"ts_col": "left_event_ts",
- "partition_cols": ["symbol"],
+ "series_ids": ["symbol"],
"data": [
["S1", "2022-01-01 09:59:59.123456789", 349.21, null, null, null],
["S1", "2022-01-01 10:00:00.123456788", 351.32, "2022-01-01 10:00:00.12345677", 365.33, 358.91],
diff --git a/python/tests/unit_test_data/delta_writer_tests.json b/python/tests/unit_test_data/delta_writer_tests.json
index 752d826b..203b93d2 100644
--- a/python/tests/unit_test_data/delta_writer_tests.json
+++ b/python/tests/unit_test_data/delta_writer_tests.json
@@ -4,7 +4,7 @@
"init": {
"schema": "symbol string, date string, event_ts string, trade_pr float, trade_pr_2 float",
"ts_col": "event_ts",
- "partition_cols": ["symbol"],
+ "series_ids": ["symbol"],
"data": [
["S1", "SAME_DT", "2020-08-01 00:00:10", 349.21, 10.0],
["S1", "SAME_DT", "2020-08-01 00:00:11", 340.21, 9.0],
diff --git a/python/tests/unit_test_data/interpol_tests.json b/python/tests/unit_test_data/interpol_tests.json
index ef8959c3..7548305b 100644
--- a/python/tests/unit_test_data/interpol_tests.json
+++ b/python/tests/unit_test_data/interpol_tests.json
@@ -3,7 +3,7 @@
"input_data": {
"schema": "partition_a string, partition_b string, event_ts string, value_a float, value_b float",
"ts_col": "event_ts",
- "partition_cols": ["partition_a", "partition_b"],
+ "series_ids": ["partition_a", "partition_b"],
"data": [
[
"A",
@@ -66,7 +66,7 @@
"simple_input_data": {
"schema": "partition_a string, partition_b string, event_ts string, value_a float, value_b float",
"ts_col": "event_ts",
- "partition_cols": ["partition_a", "partition_b"],
+ "series_ids": ["partition_a", "partition_b"],
"data": [
[
"A",
@@ -148,7 +148,7 @@
"expected_data": {
"schema": "partition_a string, partition_b string, event_ts string, value_a double, value_b double, is_ts_interpolated boolean, is_interpolated_value_a boolean, is_interpolated_value_b boolean",
"ts_col": "event_ts",
- "partition_cols": ["partition_a", "partition_b"],
+ "series_ids": ["partition_a", "partition_b"],
"data": [
[
"A",
@@ -280,7 +280,7 @@
"expected_data": {
"schema": "partition_a string, partition_b string, event_ts string, value_a double, value_b double, is_ts_interpolated boolean, is_interpolated_value_a boolean, is_interpolated_value_b boolean",
"ts_col": "event_ts",
- "partition_cols": ["partition_a", "partition_b"],
+ "series_ids": ["partition_a", "partition_b"],
"data": [
[
"A",
@@ -412,7 +412,7 @@
"expected_data": {
"schema": "partition_a string, partition_b string, event_ts string, value_a double, value_b double, is_ts_interpolated boolean, is_interpolated_value_a boolean, is_interpolated_value_b boolean",
"ts_col": "event_ts",
- "partition_cols": ["partition_a", "partition_b"],
+ "series_ids": ["partition_a", "partition_b"],
"data": [
[
"A",
@@ -544,7 +544,7 @@
"expected_data": {
"schema": "partition_a string, partition_b string, event_ts string, value_a double, value_b double, is_ts_interpolated boolean, is_interpolated_value_a boolean, is_interpolated_value_b boolean",
"ts_col": "event_ts",
- "partition_cols": ["partition_a", "partition_b"],
+ "series_ids": ["partition_a", "partition_b"],
"data": [
[
"A",
@@ -676,7 +676,7 @@
"expected_data": {
"schema": "partition_a string, partition_b string, event_ts string, value_a double, value_b double, is_ts_interpolated boolean, is_interpolated_value_a boolean, is_interpolated_value_b boolean",
"ts_col": "event_ts",
- "partition_cols": ["partition_a", "partition_b"],
+ "series_ids": ["partition_a", "partition_b"],
"data": [
[
"A",
@@ -808,7 +808,7 @@
"expected_data": {
"schema": "partition_a string, partition_b string, event_ts string, value_a double, value_b double, is_ts_interpolated boolean, is_interpolated_value_a boolean, is_interpolated_value_b boolean",
"ts_col": "event_ts",
- "partition_cols": ["partition_a", "partition_b"],
+ "series_ids": ["partition_a", "partition_b"],
"data": [
[
"A",
@@ -940,7 +940,7 @@
"expected_data": {
"schema": "partition_a string, partition_b string, event_ts string, value_a double, value_b double",
"ts_col": "event_ts",
- "partition_cols": ["partition_a", "partition_b"],
+ "series_ids": ["partition_a", "partition_b"],
"data": [
[
"A",
@@ -1041,7 +1041,7 @@
"expected": {
"schema": "partition_a string, partition_b string, event_ts string, value_a double, value_b double",
"ts_col": "event_ts",
- "partition_cols": ["partition_a", "partition_b"],
+ "series_ids": ["partition_a", "partition_b"],
"data": [
[
"A",
@@ -1140,7 +1140,7 @@
"expected": {
"schema": "partition_a string, partition_b string, other_ts_col string, value_a double, is_ts_interpolated boolean, is_interpolated_value_a boolean",
"ts_col": "other_ts_col",
- "partition_cols": ["partition_a", "partition_b"],
+ "series_ids": ["partition_a", "partition_b"],
"data": [
[
"A",
@@ -1251,7 +1251,7 @@
"expected": {
"schema": "partition_a string, partition_b string, event_ts string, value_a double, is_ts_interpolated boolean, is_interpolated_value_a boolean",
"ts_col": "event_ts",
- "partition_cols": ["partition_a", "partition_b"],
+ "series_ids": ["partition_a", "partition_b"],
"data": [
["A", "A-1", "2020-01-01 00:00:00", 0.0, false, false],
["A", "A-1", "2020-01-01 00:00:30", 1.0, true, true],
@@ -1278,7 +1278,7 @@
"expected": {
"schema": "partition_a string, partition_b string, event_ts string, value_a double, value_b double",
"ts_col": "event_ts",
- "partition_cols": ["partition_a", "partition_b"],
+ "series_ids": ["partition_a", "partition_b"],
"data": [
["A", "A-1", "2020-01-01 00:00:00", 0.0, null],
["A", "A-1", "2020-01-01 00:00:30", 0.0, null],
diff --git a/python/tests/unit_test_data/tsdf_tests.json b/python/tests/unit_test_data/tsdf_tests.json
index c5de5be1..3b522fcd 100644
--- a/python/tests/unit_test_data/tsdf_tests.json
+++ b/python/tests/unit_test_data/tsdf_tests.json
@@ -3,50 +3,16 @@
"temp_slice_init_data": {
"schema": "symbol string, event_ts string, trade_pr float",
"ts_col": "event_ts",
- "partition_cols": [
- "symbol"
- ],
+ "series_ids": [ "symbol" ],
"data": [
- [
- "S1",
- "2020-08-01 00:00:10",
- 349.21
- ],
- [
- "S1",
- "2020-08-01 00:01:12",
- 351.32
- ],
- [
- "S1",
- "2020-09-01 00:02:10",
- 361.1
- ],
- [
- "S1",
- "2020-09-01 00:19:12",
- 362.1
- ],
- [
- "S2",
- "2020-08-01 00:01:10",
- 743.01
- ],
- [
- "S2",
- "2020-08-01 00:01:24",
- 751.92
- ],
- [
- "S2",
- "2020-09-01 00:02:10",
- 761.10
- ],
- [
- "S2",
- "2020-09-01 00:20:42",
- 762.33
- ]
+ ["S1", "2020-08-01 00:00:10", 349.21],
+ ["S1", "2020-08-01 00:01:12", 351.32],
+ ["S1", "2020-09-01 00:02:10", 361.1],
+ ["S1", "2020-09-01 00:19:12", 362.1],
+ ["S2", "2020-08-01 00:01:10", 743.01],
+ ["S2", "2020-08-01 00:01:24", 751.92],
+ ["S2", "2020-09-01 00:02:10", 761.10],
+ ["S2", "2020-09-01 00:20:42", 762.33]
]
}
},
@@ -55,30 +21,12 @@
"init": {
"schema": "symbol string, event_ts string, trade_pr float",
"ts_col": "event_ts",
- "partition_cols": [
- "symbol"
- ],
+ "series_ids": ["symbol"],
"data": [
- [
- "S1",
- "2020-08-01 00:00:10",
- 349.21
- ],
- [
- "S1",
- "2020-08-01 00:01:12",
- 351.32
- ],
- [
- "S1",
- "2020-09-01 00:02:10",
- 361.1
- ],
- [
- "S1",
- "2020-09-01 00:19:12",
- 362.1
- ]
+ ["S1", "2020-08-01 00:00:10", 349.21],
+ ["S1", "2020-08-01 00:01:12", 351.32],
+ ["S1", "2020-09-01 00:02:10", 361.1],
+ ["S1", "2020-09-01 00:19:12", 362.1]
]
}
},
@@ -89,20 +37,10 @@
"expected": {
"schema": "symbol string, event_ts string, trade_pr float",
"ts_col": "event_ts",
- "partition_cols": [
- "symbol"
- ],
+ "series_ids": ["symbol"],
"data": [
- [
- "S1",
- "2020-09-01 00:02:10",
- 361.1
- ],
- [
- "S2",
- "2020-09-01 00:02:10",
- 761.10
- ]
+ ["S1", "2020-09-01 00:02:10", 361.1],
+ ["S2", "2020-09-01 00:02:10", 761.10]
]
}
},
@@ -113,30 +51,12 @@
"expected": {
"schema": "symbol string, event_ts string, trade_pr float",
"ts_col": "event_ts",
- "partition_cols": [
- "symbol"
- ],
+ "series_ids": ["symbol"],
"data": [
- [
- "S1",
- "2020-08-01 00:00:10",
- 349.21
- ],
- [
- "S1",
- "2020-08-01 00:01:12",
- 351.32
- ],
- [
- "S2",
- "2020-08-01 00:01:10",
- 743.01
- ],
- [
- "S2",
- "2020-08-01 00:01:24",
- 751.92
- ]
+ ["S1", "2020-08-01 00:00:10", 349.21],
+ ["S1", "2020-08-01 00:01:12", 351.32],
+ ["S2", "2020-08-01 00:01:10", 743.01],
+ ["S2", "2020-08-01 00:01:24", 751.92]
]
}
},
@@ -147,9 +67,7 @@
"expected": {
"schema": "symbol string, event_ts string, trade_pr float",
"ts_col": "event_ts",
- "partition_cols": [
- "symbol"
- ],
+ "series_ids": ["symbol"],
"data": [
[
"S1",
@@ -161,26 +79,10 @@
"2020-08-01 00:01:12",
351.32
],
- [
- "S1",
- "2020-09-01 00:02:10",
- 361.1
- ],
- [
- "S2",
- "2020-08-01 00:01:10",
- 743.01
- ],
- [
- "S2",
- "2020-08-01 00:01:24",
- 751.92
- ],
- [
- "S2",
- "2020-09-01 00:02:10",
- 761.10
- ]
+ ["S1", "2020-09-01 00:02:10", 361.1],
+ ["S2", "2020-08-01 00:01:10", 743.01],
+ ["S2", "2020-08-01 00:01:24", 751.92],
+ ["S2", "2020-09-01 00:02:10", 761.10]
]
}
},
@@ -191,7 +93,7 @@
"expected": {
"schema": "symbol string, event_ts string, trade_pr float",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"symbol"
],
"data": [
@@ -215,7 +117,7 @@
"expected": {
"schema": "symbol string, event_ts string, trade_pr float",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"symbol"
],
"data": [
@@ -249,7 +151,7 @@
"expected": {
"schema": "symbol string, event_ts string, trade_pr float",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"symbol"
],
"data": [
@@ -288,7 +190,7 @@
"expected": {
"schema": "symbol string, event_ts string, trade_pr float",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"symbol"
],
"data": [
@@ -322,7 +224,7 @@
"expected": {
"schema": "symbol string, event_ts string, trade_pr float",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"symbol"
],
"data": [
@@ -366,7 +268,7 @@
"expected": {
"schema": "symbol string, event_ts string, trade_pr float",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"symbol"
],
"data": [
@@ -410,7 +312,7 @@
"expected": {
"schema": "symbol string, event_ts string, trade_pr float",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"symbol"
],
"data": [
@@ -434,7 +336,7 @@
"expected": {
"schema": "symbol string, event_ts string, trade_pr float",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"symbol"
],
"data": [
@@ -457,7 +359,7 @@
"init": {
"schema": "group string, time long, val double",
"ts_col": "time",
- "partition_cols": [
+ "series_ids": [
"group"
],
"data": [
@@ -506,7 +408,7 @@
"expected": {
"schema": "group string, time long, val double, freq double, ft_real double, ft_imag double",
"ts_col": "time",
- "partition_cols": [
+ "series_ids": [
"group"
],
"data": [
@@ -583,7 +485,7 @@
"init": {
"schema": "symbol string, event_ts string, trade_pr float",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"symbol"
],
"data": [
@@ -612,7 +514,7 @@
"expected": {
"schema": "symbol string, event_ts string, mean_trade_pr float, count_trade_pr long, min_trade_pr float, max_trade_pr float, sum_trade_pr float, stddev_trade_pr float, zscore_trade_pr float",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"symbol"
],
"data": [
@@ -667,7 +569,7 @@
"init": {
"schema": "symbol string, event_ts string, trade_pr float, index integer",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"symbol"
],
"data": [
@@ -700,7 +602,7 @@
"expected": {
"schema": "symbol string, event_ts string, mean_trade_pr float, count_trade_pr long, min_trade_pr float, max_trade_pr float, sum_trade_pr float, stddev_trade_pr float, mean_index integer, count_index integer, min_index integer, max_index integer, sum_index integer, stddev_index integer",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"symbol"
],
"data": [
@@ -745,7 +647,7 @@
"input": {
"schema": "symbol string, date string, event_ts string, trade_pr float, trade_pr_2 float",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"symbol"
],
"data": [
@@ -803,7 +705,7 @@
"expected": {
"schema": "symbol string, event_ts string, floor_trade_pr float, floor_date string, floor_trade_pr_2 float",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"symbol"
],
"data": [
@@ -840,7 +742,7 @@
"expected30m": {
"schema": "symbol string, event_ts string, date double, trade_pr double, trade_pr_2 double",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"symbol"
],
"data": [
@@ -870,7 +772,7 @@
"expectedbars": {
"schema": "symbol string, event_ts string, close_trade_pr float, close_trade_pr_2 float, high_trade_pr float, high_trade_pr_2 float, low_trade_pr float, low_trade_pr_2 float, open_trade_pr float, open_trade_pr_2 float",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"symbol"
],
"data": [
@@ -929,7 +831,7 @@
"init": {
"schema": "symbol string, date string, event_ts string, trade_pr float, trade_pr_2 float",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"symbol"
],
"data": [
@@ -959,7 +861,7 @@
"expectedms": {
"schema": "symbol string, event_ts string, date double, trade_pr double, trade_pr_2 double",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"symbol"
],
"data": [
@@ -984,7 +886,7 @@
"input": {
"schema": "symbol string, date string, event_ts string, trade_pr float, trade_pr_2 float",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"symbol"
],
"data": [
@@ -1042,7 +944,7 @@
"expected": {
"schema": "symbol string, event_ts string, floor_trade_pr float, floor_date string, floor_trade_pr_2 float",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"symbol"
],
"data": [
@@ -1079,7 +981,7 @@
"expected30m": {
"schema": "symbol string, event_ts string, date double, trade_pr double, trade_pr_2 double",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"symbol"
],
"data": [
@@ -1116,7 +1018,7 @@
"expectedbars": {
"schema": "symbol string, event_ts string, close_trade_pr float, close_trade_pr_2 float, high_trade_pr float, high_trade_pr_2 float, low_trade_pr float, low_trade_pr_2 float, open_trade_pr float, open_trade_pr_2 float",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"symbol"
],
"data": [
@@ -1177,7 +1079,7 @@
"input": {
"schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT NOT NULL, metric_2 FLOAT NOT NULL, metric_3 FLOAT NOT NULL",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"identifier_1",
"identifier_2",
"identifier_3"
@@ -1312,7 +1214,7 @@
"input": {
"schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT, metric_2 FLOAT, metric_3 FLOAT",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"identifier_1",
"identifier_2",
"identifier_3"
@@ -1440,7 +1342,7 @@
"input": {
"schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT NOT NULL, metric_2 FLOAT NOT NULL, metric_3 FLOAT NOT NULL",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"identifier_1",
"identifier_2",
"identifier_3"
@@ -1575,7 +1477,7 @@
"input": {
"schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT NOT NULL, metric_2 FLOAT NOT NULL, metric_3 FLOAT NOT NULL",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"identifier_1",
"identifier_2",
"identifier_3"
@@ -1631,7 +1533,7 @@
"input": {
"schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT NOT NULL, metric_2 FLOAT NOT NULL, metric_3 FLOAT NOT NULL",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"identifier_1",
"identifier_2",
"identifier_3"
@@ -1773,7 +1675,7 @@
"input": {
"schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT NOT NULL, metric_2 FLOAT NOT NULL, metric_3 FLOAT NOT NULL",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"identifier_1",
"identifier_2",
"identifier_3"
@@ -1829,7 +1731,7 @@
"input": {
"schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT NOT NULL, metric_2 FLOAT NOT NULL, metric_3 FLOAT NOT NULL",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"identifier_1",
"identifier_2",
"identifier_3"
@@ -1964,7 +1866,7 @@
"input": {
"schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT NOT NULL, metric_2 FLOAT NOT NULL, metric_3 FLOAT NOT NULL",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"identifier_1",
"identifier_2",
"identifier_3"
@@ -2020,7 +1922,7 @@
"input": {
"schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT NOT NULL, metric_2 FLOAT NOT NULL, metric_3 FLOAT NOT NULL",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"identifier_1",
"identifier_2",
"identifier_3"
@@ -2155,7 +2057,7 @@
"input": {
"schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT NOT NULL, metric_2 FLOAT NOT NULL, metric_3 FLOAT NOT NULL",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"identifier_1",
"identifier_2",
"identifier_3"
@@ -2211,7 +2113,7 @@
"input": {
"schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT NOT NULL, metric_2 FLOAT NOT NULL, metric_3 FLOAT NOT NULL",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"identifier_1",
"identifier_2",
"identifier_3"
@@ -2360,7 +2262,7 @@
"input": {
"schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT NOT NULL, metric_2 FLOAT NOT NULL, metric_3 FLOAT NOT NULL",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"identifier_1",
"identifier_2",
"identifier_3"
@@ -2416,7 +2318,7 @@
"input": {
"schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT NOT NULL, metric_2 FLOAT NOT NULL, metric_3 FLOAT NOT NULL",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"identifier_1",
"identifier_2",
"identifier_3"
@@ -2551,7 +2453,7 @@
"input": {
"schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT, metric_2 FLOAT, metric_3 FLOAT",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"identifier_1",
"identifier_2",
"identifier_3"
@@ -2686,7 +2588,7 @@
"input": {
"schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT, metric_2 FLOAT, metric_3 FLOAT",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"identifier_1",
"identifier_2",
"identifier_3"
@@ -2814,7 +2716,7 @@
"input": {
"schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT, metric_2 FLOAT, metric_3 FLOAT",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"identifier_1",
"identifier_2",
"identifier_3"
@@ -2920,7 +2822,7 @@
"input": {
"schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT NOT NULL, metric_2 FLOAT NOT NULL, metric_3 FLOAT NOT NULL",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"identifier_1",
"identifier_2",
"identifier_3"
@@ -2942,7 +2844,7 @@
"input": {
"schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT NOT NULL, metric_2 FLOAT NOT NULL, metric_3 FLOAT NOT NULL",
"ts_col": "event_ts",
- "partition_cols": [
+ "series_ids": [
"identifier_1",
"identifier_2",
"identifier_3"
diff --git a/python/tests/unit_test_data/utils_tests.json b/python/tests/unit_test_data/utils_tests.json
index b56d5903..7cd9e253 100644
--- a/python/tests/unit_test_data/utils_tests.json
+++ b/python/tests/unit_test_data/utils_tests.json
@@ -4,7 +4,7 @@
"simple_input": {
"schema": "partition_a string, partition_b string, event_ts string, value_a float, value_b float",
"ts_col": "event_ts",
- "partition_cols": ["partition_a", "partition_b"],
+ "series_ids": ["partition_a", "partition_b"],
"data": [
["A", "A-1", "2020-01-01 00:00:10", 0.0, null],
["A", "A-1", "2020-01-01 00:01:10", 2.0, 2.0],
From 5519dafb3ab65606913894fbee57c2bfe6e894c9 Mon Sep 17 00:00:00 2001
From: Tristan Nixon
Date: Mon, 29 Aug 2022 11:33:32 -0700
Subject: [PATCH 03/11] getting tsdf_tests.BasicTests to pass
---
python/tempo/tsdf.py | 19 ++++++-------------
python/tempo/tsschema.py | 10 +++++-----
2 files changed, 11 insertions(+), 18 deletions(-)
diff --git a/python/tempo/tsdf.py b/python/tempo/tsdf.py
index f124b031..8662986d 100644
--- a/python/tempo/tsdf.py
+++ b/python/tempo/tsdf.py
@@ -18,7 +18,7 @@
import tempo.io as tio
import tempo.resample as rs
from tempo.interpol import Interpolation
-from tempo.tsschema import TSSchema
+from tempo.tsschema import TSIndex, TSSchema
from tempo.utils import (
ENV_CAN_RENDER_HTML,
IS_DATABRICKS,
@@ -68,8 +68,8 @@ def fromDateString(cls, df: DataFrame, ts_col: str, series_ids: Collection[str],
pass
@property
- def ts_index(self) -> str:
- return self.ts_schema.ts_index
+ def ts_index(self) -> "TSIndex":
+ return self.ts_schema.ts_idx
@property
def ts_col(self) -> str:
@@ -886,16 +886,9 @@ def asofJoin(
return asofDF
def __baseWindow(self, sort_col=None, reverse=False):
- # are we ordering forwards (default) or reveresed?
- col_fn = f.col
- if reverse:
- col_fn = lambda colname: f.col(colname).desc() # noqa E731
-
- # our window will be sorted on our sort_cols in the appropriate direction
- if reverse:
- w = Window().orderBy(f.col(self.ts_index).desc())
- else:
- w = Window().orderBy(f.col(self.ts_index).asc())
+ # The index will determine the appropriate sort order
+ w = Window().orderBy(self.ts_index.orderByExpr(reverse))
+
# and partitioned by any series IDs
if self.series_ids:
w = w.partitionBy([f.col(sid) for sid in self.series_ids])
diff --git a/python/tempo/tsschema.py b/python/tempo/tsschema.py
index 4e425820..81233be8 100644
--- a/python/tempo/tsschema.py
+++ b/python/tempo/tsschema.py
@@ -24,6 +24,10 @@ class TSIndex(ABC):
DoubleType(),
)
+ @classmethod
+ def isValidTSType(cls, dataType: DataType) -> bool:
+ return dataType in cls.__valid_ts_types
+
def __init__(self, name: str, dataType: DataType) -> None:
self.name = name
self.dataType = dataType
@@ -45,7 +49,7 @@ class SimpleTSIndex(TSIndex):
"""
def __init__(self, ts_col: StructField) -> None:
- if ts_col.dataType not in self.__valid_ts_types:
+ if not self.isValidTSType(ts_col.dataType):
raise TypeError(f"DataType {ts_col.dataType} of column {ts_col.name} is not valid for a timeseries Index")
super().__init__(ts_col.name, ts_col.dataType)
@@ -127,10 +131,6 @@ def fromDFSchema(
ts_idx = SimpleTSIndex(df_schema[ts_col])
return cls(ts_idx, series_ids)
- @property
- def ts_index(self) -> str:
- return self.ts_idx.name
-
@property
def structural_columns(self) -> set[str]:
"""
From 608427710c7873b397a1ddddb03fe5372c282782 Mon Sep 17 00:00:00 2001
From: Tristan Nixon
Date: Wed, 31 Aug 2022 12:28:31 -0700
Subject: [PATCH 04/11] big search & replace: partition_cols -> series_ids
getting test code passing
---
python/README.md | 45 ++--
python/tempo/interpol.py | 16 +-
python/tempo/resample.py | 20 +-
python/tempo/tsdf.py | 207 +++++----------
python/tempo/tsschema.py | 267 +++++++++++++++-----
python/tests/base.py | 12 +-
python/tests/interpol_tests.py | 30 +--
python/tests/unit_test_data/tsdf_tests.json | 48 +---
8 files changed, 337 insertions(+), 308 deletions(-)
diff --git a/python/README.md b/python/README.md
index 62dc59c7..01ab82eb 100644
--- a/python/README.md
+++ b/python/README.md
@@ -51,7 +51,7 @@ phone_accel_df = spark.read.format("csv").option("header", "true").load("dbfs:/h
from tempo import *
-phone_accel_tsdf = TSDF(phone_accel_df, ts_col="event_ts", partition_cols = ["User"])
+phone_accel_tsdf = TSDF(phone_accel_df, ts_col="event_ts", series_ids = ["User"])
display(phone_accel_tsdf)
```
@@ -65,7 +65,7 @@ Note: You can upsample any missing values by using an option in the resample int
```python
# ts_col = timestamp column on which to sort fact and source table
-# partition_cols - columns to use for partitioning the TSDF into more granular time series for windowing and sorting
+# series_ids - columns to use for partitioning the TSDF into more granular time series for windowing and sorting
resampled_sdf = phone_accel_tsdf.resample(freq='min', func='floor')
resampled_pdf = resampled_sdf.df.filter(col('event_ts').cast("date") == "2015-02-23").toPandas()
@@ -97,7 +97,7 @@ from pyspark.sql.functions import *
watch_accel_df = spark.read.format("csv").option("header", "true").load("dbfs:/home/tempo/Watch_accelerometer").withColumn("event_ts", (col("Arrival_Time").cast("double")/1000).cast("timestamp")).withColumn("x", col("x").cast("double")).withColumn("y", col("y").cast("double")).withColumn("z", col("z").cast("double")).withColumn("event_ts_dbl", col("event_ts").cast("double"))
-watch_accel_tsdf = TSDF(watch_accel_df, ts_col="event_ts", partition_cols = ["User"])
+watch_accel_tsdf = TSDF(watch_accel_df, ts_col="event_ts", series_ids = ["User"])
# Applying AS OF join to TSDF datasets
joined_df = watch_accel_tsdf.asofJoin(phone_accel_tsdf, right_prefix="phone_accel")
@@ -107,12 +107,12 @@ display(joined_df)
#### 3. Skew Join Optimized AS OF Join
-The purpose of the skew optimized as of join is to bucket each set of `partition_cols` to get the latest source record merged onto the fact table
+The purpose of the skew optimized as of join is to bucket each set of `series_ids` to get the latest source record merged onto the fact table
Parameters:
ts_col = timestamp column for sorting
-partition_cols = partition columns for defining granular time series for windowing and sorting
+series_ids = partition columns for defining granular time series for windowing and sorting
tsPartitionVal = value to break up each partition into time brackets
fraction = overlap fraction
right_prefix = prefix used for source columns when merged into fact table
@@ -185,11 +185,10 @@ Valid columns data types for interpolation are: `["int", "bigint", "float", "dou
```python
# Create instance of the TSDF class
input_tsdf = TSDF(
- input_df,
- partition_cols=["partition_a", "partition_b"],
- ts_col="event_ts",
- )
-
+ input_df,
+ series_ids=["partition_a", "partition_b"],
+ ts_col="event_ts",
+)
# What the following chain of operation does is:
# 1. Aggregate all valid numeric columns using mean into 30 second intervals
@@ -205,32 +204,32 @@ interpolated_tsdf = input_tsdf.resample(freq="30 seconds", func="mean").interpol
interpolated_tsdf = input_tsdf.interpolate(
freq="30 seconds",
func="mean",
- target_cols= ["columnA","columnB"],
+ target_cols=["columnA", "columnB"],
method="linear"
)
# Alternatively it's also possible to override default TSDF parameters.
-# e.g. partition_cols, ts_col a
+# e.g. series_ids, ts_col a
interpolated_tsdf = input_tsdf.interpolate(
- partition_cols=["partition_c"],
+ series_ids=["partition_c"],
ts_col="other_event_ts"
- freq="30 seconds",
- func="mean",
- target_cols= ["columnA","columnB"],
- method="linear"
+freq = "30 seconds",
+ func = "mean",
+ target_cols = ["columnA", "columnB"],
+ method = "linear"
)
# The show_interpolated flag can be set to `True` to show additional boolean columns
# for a given row that shows if a column has been interpolated.
interpolated_tsdf = input_tsdf.interpolate(
- partition_cols=["partition_c"],
+ series_ids=["partition_c"],
ts_col="other_event_ts"
- freq="30 seconds",
- func="mean",
- method="linear",
- target_cols= ["columnA","columnB"],
- show_interpolated=True,
+freq = "30 seconds",
+ func = "mean",
+ method = "linear",
+ target_cols = ["columnA", "columnB"],
+ show_interpolated = True,
)
```
diff --git a/python/tempo/interpol.py b/python/tempo/interpol.py
index fdffb7ba..ac9526de 100644
--- a/python/tempo/interpol.py
+++ b/python/tempo/interpol.py
@@ -265,7 +265,7 @@ def interpolate(
self,
tsdf,
ts_col: str,
- partition_cols: List[str],
+ series_ids: List[str],
target_cols: List[str],
freq: str,
func: str,
@@ -279,7 +279,7 @@ def interpolate(
:param tsdf: input TSDF
:param ts_col: timestamp column name
:param target_cols: numeric columns to interpolate
- :param partition_cols: partition columns names
+ :param series_ids: names of the series ID (partition) columns
:param freq: frequency at which to sample
:param func: aggregate function used for sampling to the specified interval
:param method: interpolation function used to fill missing values
@@ -289,7 +289,7 @@ def interpolate(
"""
# Validate input parameters
self.__validate_fill(method)
- self.__validate_col(tsdf.df, partition_cols, target_cols, ts_col)
+ self.__validate_col(tsdf.df, series_ids, target_cols, ts_col)
# Convert Frequency using resample dictionary
parsed_freq = checkAllowableFreq(freq)
@@ -297,10 +297,10 @@ def interpolate(
# Throw warning for user to validate that the expected number of output rows is valid.
if perform_checks:
- calculate_time_horizon(tsdf.df, ts_col, freq, partition_cols)
+ calculate_time_horizon(tsdf.df, ts_col, freq, series_ids)
# Only select required columns for interpolation
- input_cols: List[str] = [*partition_cols, ts_col, *target_cols]
+ input_cols: List[str] = [*series_ids, ts_col, *target_cols]
sampled_input: DataFrame = tsdf.df.select(*input_cols)
if self.is_resampled is False:
@@ -311,7 +311,7 @@ def interpolate(
# Fill timeseries for nearest values
time_series_filled = self.__generate_time_series_fill(
- sampled_input, partition_cols, ts_col
+ sampled_input, series_ids, ts_col
)
# Generate surrogate timestamps for each target column
@@ -323,7 +323,7 @@ def interpolate(
when(col(column).isNull(), None).otherwise(col(ts_col)),
)
add_column_time = self.__generate_column_time_fill(
- add_column_time, partition_cols, ts_col, column
+ add_column_time, series_ids, ts_col, column
)
# Handle edge case if last value (latest) is null
@@ -338,7 +338,7 @@ def interpolate(
target_column_filled = edge_filled
for column in target_cols:
target_column_filled = self.__generate_target_fill(
- target_column_filled, partition_cols, ts_col, column
+ target_column_filled, series_ids, ts_col, column
)
# Generate missing timeseries values
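With this rename, callers of `Interpolation.interpolate` pass `series_ids` where they previously passed `partition_cols`; the rest of the signature is unchanged. A hedged usage sketch, assuming a `TSDF` named `input_tsdf` over columns `partition_a`, `partition_b`, `event_ts`, and `value_a`, and keyword names as shown in the hunks above:

```python
from tempo.interpol import Interpolation

interp = Interpolation(is_resampled=False)
filled_df = interp.interpolate(
    tsdf=input_tsdf,
    ts_col="event_ts",
    series_ids=["partition_a", "partition_b"],  # formerly partition_cols
    target_cols=["value_a"],
    freq="30 seconds",
    func="mean",
    method="linear",
    show_interpolated=False,
)
```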
diff --git a/python/tempo/resample.py b/python/tempo/resample.py
index bf55e495..3b6003f3 100644
--- a/python/tempo/resample.py
+++ b/python/tempo/resample.py
@@ -46,7 +46,7 @@ def __appendAggKey(tsdf, freq=None):
df = df.withColumn("agg_key", agg_window)
return (
- tempo.TSDF(df, tsdf.ts_col, partition_cols=tsdf.partitionCols),
+ tempo.TSDF(df, ts_col=tsdf.ts_col, series_ids=tsdf.series_ids),
parsed_freq[0],
freq_dict[parsed_freq[1]],
)
@@ -65,7 +65,7 @@ def aggregate(tsdf, freq, func, metricCols=None, prefix=None, fill=None):
tsdf, period, unit = __appendAggKey(tsdf, freq)
df = tsdf.df
- groupingCols = tsdf.partitionCols + ["agg_key"]
+ groupingCols = tsdf.series_ids + ["agg_key"]
if metricCols is None:
metricCols = list(set(df.columns).difference(set(groupingCols + [tsdf.ts_col])))
if prefix is None:
@@ -90,7 +90,7 @@ def aggregate(tsdf, freq, func, metricCols=None, prefix=None, fill=None):
res = df.groupBy(groupingCols).agg(exprs)
agg_metric_cls = list(
set(res.columns).difference(
- set(tsdf.partitionCols + [tsdf.ts_col, "agg_key"])
+ set(tsdf.series_ids + [tsdf.ts_col, "agg_key"])
)
)
new_cols = [
@@ -103,7 +103,7 @@ def aggregate(tsdf, freq, func, metricCols=None, prefix=None, fill=None):
res = df.groupBy(groupingCols).agg(exprs)
agg_metric_cls = list(
set(res.columns).difference(
- set(tsdf.partitionCols + [tsdf.ts_col, "agg_key"])
+ set(tsdf.series_ids + [tsdf.ts_col, "agg_key"])
)
)
new_cols = [
@@ -116,7 +116,7 @@ def aggregate(tsdf, freq, func, metricCols=None, prefix=None, fill=None):
res = df.groupBy(groupingCols).agg(exprs)
agg_metric_cls = list(
set(res.columns).difference(
- set(tsdf.partitionCols + [tsdf.ts_col, "agg_key"])
+ set(tsdf.series_ids + [tsdf.ts_col, "agg_key"])
)
)
new_cols = [
@@ -143,15 +143,15 @@ def aggregate(tsdf, freq, func, metricCols=None, prefix=None, fill=None):
)
# sort columns so they are consistent
- non_part_cols = set(set(res.columns) - set(tsdf.partitionCols)) - set([tsdf.ts_col])
- sel_and_sort = tsdf.partitionCols + [tsdf.ts_col] + sorted(non_part_cols)
+ non_part_cols = set(set(res.columns) - set(tsdf.series_ids)) - {tsdf.ts_col}
+ sel_and_sort = tsdf.series_ids + [tsdf.ts_col] + sorted(non_part_cols)
res = res.select(sel_and_sort)
- fillW = Window.partitionBy(tsdf.partitionCols)
+ fillW = Window.partitionBy(tsdf.series_ids)
imputes = (
res.select(
- *tsdf.partitionCols,
+ *tsdf.series_ids,
f.min(tsdf.ts_col).over(fillW).alias("from"),
f.max(tsdf.ts_col).over(fillW).alias("until")
)
@@ -172,7 +172,7 @@ def aggregate(tsdf, freq, func, metricCols=None, prefix=None, fill=None):
if fill:
res = imputes.join(
- res, tsdf.partitionCols + [tsdf.ts_col], "leftouter"
+ res, tsdf.series_ids + [tsdf.ts_col], "leftouter"
).na.fill(0, metrics)
return res
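`resample.aggregate` now groups on `tsdf.series_ids` plus the generated `agg_key`, and the optional fill joins imputed rows back in per series. A sketch against the signature shown above, assuming a `TSDF` named `trades_tsdf` with a numeric `trade_pr` column:

```python
import tempo.resample as rs

# mean of trade_pr per series per minute; fill=True zero-fills empty intervals
bars_df = rs.aggregate(trades_tsdf, freq="min", func="mean",
                       metricCols=["trade_pr"], fill=True)
```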
diff --git a/python/tempo/tsdf.py b/python/tempo/tsdf.py
index 8662986d..d792142c 100644
--- a/python/tempo/tsdf.py
+++ b/python/tempo/tsdf.py
@@ -85,14 +85,11 @@ def structural_cols(self) -> Set[str]:
@property
def observational_cols(self) -> List[str]:
- return [
- col.name
- for col in self.ts_schema.find_observational_columns(self.df.schema)
- ]
+ return list(self.ts_schema.find_observational_columns(self.df.schema))
@property
def metric_cols(self) -> List[str]:
- return [col.name for col in self.ts_schema.find_metric_columns(self.df.schema)]
+ return self.ts_schema.find_metric_columns(self.df.schema)
# def __init__(self, df, ts_col="event_ts", partition_cols=None, sequence_col=None):
# """
@@ -207,15 +204,10 @@ def __addPrefixToColumns(self, col_list, prefix):
if prefix == "":
ts_col = self.ts_col
- seq_col = self.sequence_col if self.sequence_col else self.sequence_col
else:
ts_col = "".join([prefix, self.ts_col])
- seq_col = (
- "".join([prefix, self.sequence_col])
- if self.sequence_col
- else self.sequence_col
- )
- return TSDF(df, ts_col, self.series_ids, sequence_col=seq_col)
+
+ return TSDF(df, ts_col=ts_col, series_ids=self.series_ids)
def __addColumnsFromOtherDF(self, other_cols):
"""
@@ -227,14 +219,14 @@ def __addColumnsFromOtherDF(self, other_cols):
self.df,
)
- return TSDF(new_df, self.ts_col, self.series_ids)
+ return self.__withTransformedDF(new_df)
def __combineTSDF(self, ts_df_right, combined_ts_col):
combined_df = self.df.unionByName(ts_df_right.df).withColumn(
combined_ts_col, f.coalesce(self.ts_col, ts_df_right.ts_col)
)
- return TSDF(combined_df, combined_ts_col, self.series_ids)
+ return TSDF(combined_df, ts_col=combined_ts_col, series_ids=self.series_ids)
def __getLastRightRow(
self,
@@ -332,7 +324,7 @@ def __getLastRightRow(
)
df = df.drop(column)
- return TSDF(df, left_ts_col, self.series_ids)
+ return TSDF(df, ts_col=left_ts_col, series_ids=self.series_ids)
def __getTimePartitions(self, tsPartitionVal, fraction=0.1):
"""
@@ -373,7 +365,7 @@ def __getTimePartitions(self, tsPartitionVal, fraction=0.1):
df = partition_df.union(remainder_df).drop(
"partition_remainder", "ts_col_double"
)
- return TSDF(df, self.ts_col, self.series_ids + ["ts_partition"])
+ return TSDF(df, ts_col=self.ts_col, series_ids=self.series_ids + ["ts_partition"])
#
# Slicing & Selection
@@ -398,15 +390,8 @@ def select(self, *cols):
"""
# The columns which will be a mandatory requirement while selecting from TSDFs
- seq_col_stub = [] if bool(self.sequence_col) is False else [self.sequence_col]
- mandatory_cols = [self.ts_col] + self.series_ids + seq_col_stub
- if set(mandatory_cols).issubset(set(cols)):
- return TSDF(
- self.df.select(*cols),
- self.ts_col,
- self.series_ids,
- self.sequence_col,
- )
+ if set(self.structural_cols).issubset(set(cols)):
+ return self.__withTransformedDF(self.df.select(*cols))
else:
raise Exception(
"In TSDF's select statement original ts_col, partitionCols and seq_col_stub(optional) must be present"
@@ -881,11 +866,11 @@ def asofJoin(
"ts_partition", "is_original"
)
- asofDF = TSDF(df, asofDF.ts_col, combined_df.series_ids)
+ asofDF = TSDF(df, ts_col=asofDF.ts_col, series_ids=combined_df.series_ids)
return asofDF
- def __baseWindow(self, sort_col=None, reverse=False):
+ def __baseWindow(self, reverse=False):
# The index will determine the appropriate sort order
w = Window().orderBy(self.ts_index.orderByExpr(reverse))
@@ -894,14 +879,14 @@ def __baseWindow(self, sort_col=None, reverse=False):
w = w.partitionBy([f.col(sid) for sid in self.series_ids])
return w
- def __rangeBetweenWindow(self, range_from, range_to, sort_col=None, reverse=False):
- return self.__baseWindow(sort_col=sort_col, reverse=reverse).rangeBetween(
- range_from, range_to
- )
-
def __rowsBetweenWindow(self, rows_from, rows_to, reverse=False):
return self.__baseWindow(reverse=reverse).rowsBetween(rows_from, rows_to)
+ def __rangeBetweenWindow(self, range_from, range_to, reverse=False):
+ return (self.__baseWindow(reverse=reverse)
+ .orderBy(self.ts_index.rangeOrderByExpr(reverse=reverse))
+ .rangeBetween(range_from, range_to))
+
def vwap(self, frequency="m", volume_col="volume", price_col="price"):
# set pre_vwap as self or enrich with the frequency
pre_vwap = self.df
@@ -937,7 +922,7 @@ def vwap(self, frequency="m", volume_col="volume", price_col="price"):
.withColumn("vwap", f.col("dllr_value") / f.col(volume_col))
)
- return TSDF(vwapped, self.ts_col, self.series_ids)
+ return self.__withTransformedDF(vwapped)
def EMA(self, colName, window=30, exp_factor=0.2):
"""
@@ -964,7 +949,7 @@ def EMA(self, colName, window=30, exp_factor=0.2):
).drop(lagColName)
# Nulls are currently removed
- return TSDF(df, self.ts_col, self.series_ids)
+ return self.__withTransformedDF(df)
def withLookbackFeatures(
self, featureCols, lookbackWindowSize, exactSize=True, featureColName="features"
@@ -998,7 +983,7 @@ def withLookbackFeatures(
if exactSize:
return lookback_tsdf.where(f.size(featureColName) == lookbackWindowSize)
- return TSDF(lookback_tsdf, self.ts_col, self.series_ids)
+ return self.__withTransformedDF(lookback_tsdf)
def withRangeStats(
self, type="range", colsToSummarize=[], rangeBackWindowSecs=1000
@@ -1018,35 +1003,12 @@ def withRangeStats(
4. There is a cast to long from timestamp so microseconds or more likely breaks down - this could be more easily handled with a string timestamp or sorting the timestamp itself. If using a 'rows preceding' window, this wouldn't be a problem
"""
- # identify columns to summarize if not provided
- # these should include all numeric columns that
- # are not the timestamp column and not any of the partition columns
+ # by default summarize all metric columns
if not colsToSummarize:
- # columns we should never summarize
- prohibited_cols = [self.ts_col.lower()]
- if self.series_ids:
- prohibited_cols.extend([pc.lower() for pc in self.series_ids])
- # types that can be summarized
- summarizable_types = ["int", "bigint", "float", "double"]
- # filter columns to find summarizable columns
- colsToSummarize = [
- datatype[0]
- for datatype in self.df.dtypes
- if (
- (datatype[1] in summarizable_types)
- and (datatype[0].lower() not in prohibited_cols)
- )
- ]
+ colsToSummarize = self.metric_cols
# build window
- if str(self.df.schema[self.ts_col].dataType) == "TimestampType":
- self. __add_double_ts()
- prohibited_cols.extend(["double_ts"])
- w = self.__rangeBetweenWindow(
- -1 * rangeBackWindowSecs, 0, sort_col="double_ts"
- )
- else:
- w = self.__rangeBetweenWindow(-1 * rangeBackWindowSecs, 0)
+ w = self.__rangeBetweenWindow(-1 * rangeBackWindowSecs, 0)
# compute column summaries
selectedCols = self.df.columns
@@ -1069,7 +1031,7 @@ def withRangeStats(
"double_ts"
)
- return TSDF(summary_df, self.ts_col, self.series_ids)
+ return self.__withTransformedDF(summary_df)
def withGroupedStats(self, metricCols=[], freq=None):
"""
@@ -1128,7 +1090,7 @@ def withGroupedStats(self, metricCols=[], freq=None):
.drop("window")
)
- return TSDF(summary_df, self.ts_col, self.series_ids)
+ return self.__withTransformedDF(summary_df)
def write(self, spark, tabName, optimizationCols=None):
tio.write(self, spark, tabName, optimizationCols)
@@ -1176,7 +1138,7 @@ def interpolate(
method: str,
target_cols: List[str] = None,
ts_col: str = None,
- partition_cols: List[str] = None,
+ series_ids: List[str] = None,
show_interpolated: bool = False,
perform_checks: bool = True,
):
@@ -1198,10 +1160,10 @@ def interpolate(
# Set defaults for target columns, timestamp column and partition columns when not provided
if ts_col is None:
ts_col = self.ts_col
- if partition_cols is None:
- partition_cols = self.series_ids
+ if series_ids is None:
+ series_ids = self.series_ids
if target_cols is None:
- prohibited_cols: List[str] = partition_cols + [ts_col]
+ prohibited_cols: List[str] = series_ids + [ts_col]
summarizable_types = ["int", "bigint", "float", "double"]
# get summarizable find summarizable columns
@@ -1215,11 +1177,11 @@ def interpolate(
]
interpolate_service: Interpolation = Interpolation(is_resampled=False)
- tsdf_input = TSDF(self.df, ts_col=ts_col, series_ids=partition_cols)
+ tsdf_input = TSDF(self.df, ts_col=ts_col, series_ids=series_ids)
interpolated_df: DataFrame = interpolate_service.interpolate(
tsdf_input,
ts_col,
- partition_cols,
+ series_ids,
target_cols,
freq,
func,
@@ -1228,7 +1190,7 @@ def interpolate(
perform_checks,
)
- return TSDF(interpolated_df, ts_col=ts_col, series_ids=partition_cols)
+ return TSDF(interpolated_df, ts_col=ts_col, series_ids=series_ids)
def calc_bars(tsdf, freq, func=None, metricCols=None, fill=None):
@@ -1259,7 +1221,7 @@ def calc_bars(tsdf, freq, func=None, metricCols=None, fill=None):
)
bars = bars.select(sel_and_sort)
- return TSDF(bars, resample_open.ts_col, resample_open.series_ids)
+ return TSDF(bars, ts_col=resample_open.ts_col, series_ids=resample_open.series_ids)
def fourier_transform(self, timestep, valueCol):
"""
@@ -1288,77 +1250,39 @@ def tempo_fourier_util(pdf):
valueCol = self.__validated_column(self.df, valueCol)
data = self.df
- if self.sequence_col:
- if self.series_ids == []:
- data = data.withColumn("dummy_group", f.lit("dummy_val"))
- data = (
- data.select(
- f.col("dummy_group"),
- self.ts_col,
- self.sequence_col,
- f.col(valueCol),
- )
- .withColumn("tdval", f.col(valueCol))
- .withColumn("tpoints", f.col(self.ts_col))
- )
- return_schema = ",".join(
- [f"{i[0]} {i[1]}" for i in data.dtypes]
- + ["freq double", "ft_real double", "ft_imag double"]
- )
- result = data.groupBy("dummy_group").applyInPandas(
- tempo_fourier_util, return_schema
- )
- result = result.drop("dummy_group", "tdval", "tpoints")
- else:
- group_cols = self.series_ids
- data = (
- data.select(
- *group_cols, self.ts_col, self.sequence_col, f.col(valueCol)
- )
- .withColumn("tdval", f.col(valueCol))
- .withColumn("tpoints", f.col(self.ts_col))
- )
- return_schema = ",".join(
- [f"{i[0]} {i[1]}" for i in data.dtypes]
- + ["freq double", "ft_real double", "ft_imag double"]
- )
- result = data.groupBy(*group_cols).applyInPandas(
- tempo_fourier_util, return_schema
- )
- result = result.drop("tdval", "tpoints")
+
+ if self.series_ids == []:
+ data = data.withColumn("dummy_group", f.lit("dummy_val"))
+ data = (
+ data.select(f.col("dummy_group"), self.ts_col, f.col(valueCol))
+ .withColumn("tdval", f.col(valueCol))
+ .withColumn("tpoints", f.col(self.ts_col))
+ )
+ return_schema = ",".join(
+ [f"{i[0]} {i[1]}" for i in data.dtypes]
+ + ["freq double", "ft_real double", "ft_imag double"]
+ )
+ result = data.groupBy("dummy_group").applyInPandas(
+ tempo_fourier_util, return_schema
+ )
+ result = result.drop("dummy_group", "tdval", "tpoints")
else:
- if self.series_ids == []:
- data = data.withColumn("dummy_group", f.lit("dummy_val"))
- data = (
- data.select(f.col("dummy_group"), self.ts_col, f.col(valueCol))
- .withColumn("tdval", f.col(valueCol))
- .withColumn("tpoints", f.col(self.ts_col))
- )
- return_schema = ",".join(
- [f"{i[0]} {i[1]}" for i in data.dtypes]
- + ["freq double", "ft_real double", "ft_imag double"]
- )
- result = data.groupBy("dummy_group").applyInPandas(
- tempo_fourier_util, return_schema
- )
- result = result.drop("dummy_group", "tdval", "tpoints")
- else:
- group_cols = self.series_ids
- data = (
- data.select(*group_cols, self.ts_col, f.col(valueCol))
- .withColumn("tdval", f.col(valueCol))
- .withColumn("tpoints", f.col(self.ts_col))
- )
- return_schema = ",".join(
- [f"{i[0]} {i[1]}" for i in data.dtypes]
- + ["freq double", "ft_real double", "ft_imag double"]
- )
- result = data.groupBy(*group_cols).applyInPandas(
- tempo_fourier_util, return_schema
- )
- result = result.drop("tdval", "tpoints")
+ group_cols = self.series_ids
+ data = (
+ data.select(*group_cols, self.ts_col, f.col(valueCol))
+ .withColumn("tdval", f.col(valueCol))
+ .withColumn("tpoints", f.col(self.ts_col))
+ )
+ return_schema = ",".join(
+ [f"{i[0]} {i[1]}" for i in data.dtypes]
+ + ["freq double", "ft_real double", "ft_imag double"]
+ )
+ result = data.groupBy(*group_cols).applyInPandas(
+ tempo_fourier_util, return_schema
+ )
+ result = result.drop("tdval", "tpoints")
- return TSDF(result, self.ts_col, self.series_ids, self.sequence_col)
+ return self.__withTransformedDF(result)
def extractStateIntervals(
self,
@@ -1486,11 +1410,10 @@ def __init__(
df,
ts_col="event_ts",
series_ids=None,
- sequence_col=None,
freq=None,
func=None,
):
- super(_ResampledTSDF, self).__init__(df, ts_col, series_ids, sequence_col)
+ super(_ResampledTSDF, self).__init__(df, ts_col=ts_col, series_ids=series_ids)
self.__freq = freq
self.__func = func
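Two of the `tsdf.py` changes above simplify `withRangeStats`: the columns to summarize now default to `self.metric_cols`, and the range window is built from the index's `rangeOrderByExpr` rather than a hand-made `double_ts` column. A hedged sketch of the resulting call, assuming a `TSDF` named `trades_tsdf`:

```python
# summary statistics over the preceding 10 minutes, per series,
# across all metric columns detected by the schema
stats_tsdf = trades_tsdf.withRangeStats(rangeBackWindowSecs=600)
stats_tsdf.df.show()
```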
diff --git a/python/tempo/tsschema.py b/python/tempo/tsschema.py
index 81233be8..7409f3cd 100644
--- a/python/tempo/tsschema.py
+++ b/python/tempo/tsschema.py
@@ -4,96 +4,245 @@
from pyspark.sql import Column
import pyspark.sql.functions as Fn
from pyspark.sql.types import *
+from pyspark.sql.types import NumericType
+#
+# Timeseries Index Classes
+#
class TSIndex(ABC):
"""
Abstract base class for all Timeseries Index types
"""
- # Valid types for time index columns
- __valid_ts_types = (
- DateType(),
- TimestampType(),
- ByteType(),
- ShortType(),
- IntegerType(),
- LongType(),
- DecimalType(),
- FloatType(),
- DoubleType(),
- )
-
- @classmethod
- def isValidTSType(cls, dataType: DataType) -> bool:
- return dataType in cls.__valid_ts_types
+ @property
+ @abstractmethod
+ def name(self):
+ """
+ :return: the column name of the timeseries index
+ """
- def __init__(self, name: str, dataType: DataType) -> None:
- self.name = name
- self.dataType = dataType
+ def _reverseOrNot(self, expr: Union[Column, List[Column]], reverse: bool) -> Union[Column, List[Column]]:
+ if not reverse:
+ return expr # just return the expression as-is if we're not reversing
+ elif isinstance(expr, Column):
+ return expr.desc() # reverse a single-column expression
+ elif isinstance(expr, list):
+ return [col.desc() for col in expr] # reverse each column in the list
@abstractmethod
def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]:
"""
- Returns a :class:`Column` expression that will order the :class:`TSDF` according to the timeseries index.
+ Gets an expression that will order the :class:`TSDF` according to the timeseries index.
- :param reverse: whether or not the ordering should be reversed (backwards in time)
+ :param reverse: whether the ordering should be reversed (backwards in time)
:return: an expression appropriate for ordering the :class:`TSDF` according to this index
"""
- pass
-class SimpleTSIndex(TSIndex):
+ def rangeOrderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]:
+ """
+ Gets an expression appropriate for performing range operations on the :class:`TSDF` records.
+ Defaults to the same expression given by :method:`TSIndex.orderByExpr`
+
+ :param reverse: whether the ordering should be reversed (backwards in time)
+
+ :return: an expression appropriate for performing range operations on the :class:`TSDF` records
+ """
+ return self.orderByExpr(reverse=reverse)
+
+#
+# Simple TS Index types
+#
+
+class SimpleTSIndex(TSIndex, ABC):
"""
- Timeseries index based on a single column of a numeric or temporal type.
+ Abstract base class for simple Timeseries Index types
+ that only reference a single column for maintaining the temporal structure
"""
def __init__(self, ts_col: StructField) -> None:
- if not self.isValidTSType(ts_col.dataType):
- raise TypeError(f"DataType {ts_col.dataType} of column {ts_col.name} is not valid for a timeseries Index")
- super().__init__(ts_col.name, ts_col.dataType)
+ self.__name = ts_col.name
+ self.dataType = ts_col.dataType
+
+ @property
+ def name(self):
+ return self.__name
def orderByExpr(self, reverse: bool = False) -> Column:
expr = Fn.col(self.name)
- if reverse:
- return expr.desc()
- return expr
+ return self._reverseOrNot(expr, reverse)
+
+ @classmethod
+ def fromTSCol(cls, ts_col: StructField) -> "SimpleTSIndex":
+ # pick our implementation based on the column type
+ if isinstance(ts_col.dataType, NumericType):
+ return NumericIndex(ts_col)
+ elif isinstance(ts_col.dataType, TimestampType):
+ return SimpleTimestampIndex(ts_col)
+ elif isinstance(ts_col.dataType, DateType):
+ return SimpleDateIndex(ts_col)
+ else:
+ raise TypeError(f"A SimpleTSIndex must be a Numeric, Timestamp or Date type, but column {ts_col.name} is of type {ts_col.dataType}")
-class SubSequenceTSIndex(TSIndex):
+class NumericIndex(SimpleTSIndex):
"""
- Special timeseries index for columns that involve a secondary sequencing column
+ Timeseries index based on a single column of a numeric type.
"""
- # default name for our timeseries index
- __ts_idx_name = "ts_index"
- # Valid types for sub-sequence columns
- __valid_subseq_types = (
- ByteType(),
- ShortType(),
- IntegerType(),
- LongType()
- )
+ def __init__(self, ts_col: StructField) -> None:
+ if not isinstance(ts_col.dataType, NumericType):
+ raise TypeError(f"NumericIndex must be of a numeric type, but ts_col {ts_col.name} has type {ts_col.dataType}")
+ super().__init__(ts_col)
+
+
+class SimpleTimestampIndex(SimpleTSIndex):
+ """
+ Timeseries index based on a single Timestamp column
+ """
+
+ def __init__(self, ts_col: StructField) -> None:
+ if not isinstance(ts_col.dataType, TimestampType):
+ raise TypeError(f"SimpleTimestampIndex must be of TimestampType, but given ts_col {ts_col.name} has type {ts_col.dataType}")
+ super().__init__(ts_col)
+
+ def rangeOrderByExpr(self, reverse: bool = False) -> Column:
+ # cast timestamp to double (fractional seconds since epoch)
+ expr = Fn.col(self.name).cast("double")
+ return self._reverseOrNot(expr, reverse)
+
+
+class SimpleDateIndex(SimpleTSIndex):
+ """
+ Timeseries index based on a single Date column
+ """
- def __init__(self, primary_ts_col: StructField, subsequence_col: StructField) -> None:
- # validate these column types
- if primary_ts_col.dataType not in self.__valid_ts_types:
- raise TypeError(f"DataType {primary_ts_col.dataType} of column {primary_ts_col.name} is not valid for a timeseries Index")
- if subsequence_col.dataType not in self.__valid_subseq_types:
- raise TypeError(f"DataType {subsequence_col.dataType} of column {subsequence_col.name} is not valid for a sub-sequencing column")
- # construct a struct for these
- ts_struct = StructType([primary_ts_col, subsequence_col])
- super().__init__(self.__ts_idx_name, ts_struct)
- # set colnames for primary & subsequence
- self.primary_ts_col = primary_ts_col.name
- self.subsequence_col = subsequence_col.name
+ def __init__(self, ts_col: StructField) -> None:
+ if not isinstance(ts_col.dataType, DateType):
+ raise TypeError(f"DateIndex must be of DateType, but given ts_col {ts_col.name} has type {ts_col.dataType}")
+ super().__init__(ts_col)
+
+ def rangeOrderByExpr(self, reverse: bool = False) -> Column:
+ # convert date to number of days since the epoch
+ expr = Fn.datediff(Fn.col(self.name), Fn.lit("1970-01-01").cast("date"))
+ return self._reverseOrNot(expr, reverse)
+
+#
+# Compound TS Index Types
+#
+
+class CompositeTSIndex(TSIndex, ABC):
+ """
+ Abstract base class for complex Timeseries Index classes
+ that involve two or more columns organized into a StructType column
+ """
+
+ def __init__(self, composite_ts_idx: StructField, primary_ts_col: str) -> None:
+ if not isinstance(composite_ts_idx.dataType, StructType):
+ raise TypeError(f"CompoundTSIndex must be of type StructType, but given compound_ts_idx {composite_ts_idx.name} has type {composite_ts_idx.dataType}")
+ self.ts_idx: str = composite_ts_idx.name
+ self.struct: StructType = composite_ts_idx.dataType
+ # construct a simple TS index object for the primary column
+ self.primary_ts_idx: SimpleTSIndex = SimpleTSIndex.fromTSCol(self.struct[primary_ts_col])
+
+ @property
+ def name(self) -> str:
+ return self.ts_idx
+
+ @property
+ def primary_ts_col(self) -> str:
+ return self.component(self.primary_ts_idx.name)
+
+ def component(self, component_name):
+ """
+ Returns the full path to a component column that is within the composite index
+
+ :param component_name: the name of the component element within the composite index
+
+ :return: a column name that can be used to reference the component column from the :class:`TSDF`
+ """
+ return f"{self.name}.{self.struct[component_name].name}"
+
+ def orderByExpr(self, reverse: bool = False) -> Column:
+ # default to using the primary column
+ expr = Fn.col(self.primary_ts_col)
+ return self._reverseOrNot(expr, reverse)
+
+
+class SubSequenceTSIndex(CompositeTSIndex):
+ """
+ Timeseries Index when we have a primary timeseries column and a secondary sequencing
+ column that breaks ties between records sharing the same primary timestamp value
+ """
+
+ def __init__(self, composite_ts_idx: StructField, primary_ts_col: str, sub_seq_col: str) -> None:
+ super().__init__(composite_ts_idx, primary_ts_col)
+ # construct a simple index for the sub-sequence column
+ self.sub_sequence_idx = NumericIndex(self.struct[sub_seq_col])
+
+ @property
+ def sub_seq_col(self) -> str:
+ return self.component(self.sub_sequence_idx.name)
def orderByExpr(self, reverse: bool = False) -> List[Column]:
- expr = [ Fn.col(self.primary_ts_col), Fn.col(self.subsequence_col) ]
- if reverse:
- return [col.desc() for col in expr]
- return expr
+ # build a composite expression of the primary index followed by the sub-sequence index
+ exprs = [Fn.col(self.primary_ts_col), Fn.col(self.sub_seq_col)]
+ return self._reverseOrNot(exprs, reverse)
+
+
+class ParsedTSIndex(CompositeTSIndex, ABC):
+ """
+ Abstract base class for timeseries indices that are parsed from a string column.
+ Retains the original string form as well as the parsed column.
+ """
+
+ def __init__(self, composite_ts_idx: StructField, src_str_col: str, parsed_col: str) -> None:
+ super().__init__(composite_ts_idx, primary_ts_col=parsed_col)
+ src_str_field = self.struct[src_str_col]
+ if not isinstance(src_str_field.dataType, StringType):
+ raise TypeError(f"Source string column must be of StringType, but given column {src_str_field.name} is of type {src_str_field.dataType}")
+ self.__src_str_col = src_str_col
+
+ @property
+ def src_str_col(self):
+ return self.component(self.__src_str_col)
+
+
+class ParsedTimestampIndex(ParsedTSIndex):
+ """
+ Timeseries index class for timestamps parsed from a string column
+ """
+
+ def __init__(self, composite_ts_idx: StructField, src_str_col: str, parsed_col: str) -> None:
+ super().__init__(composite_ts_idx, src_str_col, parsed_col)
+ if not isinstance(self.primary_ts_idx.dataType, TimestampType):
+ raise TypeError(f"ParsedTimestampIndex must be of TimestampType, but given ts_col {self.primary_ts_idx.name} has type {self.primary_ts_idx.dataType}")
+
+ def rangeOrderByExpr(self, reverse: bool = False) -> Column:
+ # cast timestamp to double (fractional seconds since epoch)
+ expr = Fn.col(self.primary_ts_col).cast("double")
+ return self._reverseOrNot(expr, reverse)
+
+
+class ParsedDateIndex(ParsedTSIndex):
+ """
+ Timeseries index class for dates parsed from a string column
+ """
+
+ def __init__(self, composite_ts_idx: StructField, src_str_col: str, parsed_col: str) -> None:
+ super().__init__(composite_ts_idx, src_str_col, parsed_col)
+ if not isinstance(self.primary_ts_idx.dataType, DateType):
+ raise TypeError(f"ParsedDateIndex must be of DateType, but given ts_col {self.primary_ts_idx.name} has type {self.primary_ts_idx.dataType}")
+
+ def rangeOrderByExpr(self, reverse: bool = False) -> Column:
+ # convert date to number of days since the epoch
+ expr = Fn.datediff(Fn.col(self.primary_ts_col), Fn.lit("1970-01-01").cast("date"))
+ return self._reverseOrNot(expr, reverse)
+#
+# Timeseries Schema
+#
class TSSchema:
"""
@@ -128,7 +277,7 @@ def fromDFSchema(
cls, df_schema: StructType, ts_col: str, series_ids: Collection[str] = None
) -> "TSSchema":
# construct a TSIndex for the given ts_col
- ts_idx = SimpleTSIndex(df_schema[ts_col])
+ ts_idx = SimpleTSIndex.fromTSCol(df_schema[ts_col])
return cls(ts_idx, series_ids)
@property
@@ -139,7 +288,7 @@ def structural_columns(self) -> set[str]:
:return: a set of column names corresponding the structural columns of a :class:`TSDF`
"""
- struct_cols = {self.ts_index}.union(self.series_ids)
+ struct_cols = {self.ts_idx.name}.union(self.series_ids)
struct_cols.discard(None)
return struct_cols
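`TSSchema.fromDFSchema` now defers to `SimpleTSIndex.fromTSCol`, so the concrete index class is chosen from the column's Spark type and then supplies both the ordering and the range expressions. A hedged sketch of that dispatch, using the class and property names from the hunk above and an illustrative schema:

```python
from pyspark.sql.types import StructType, StructField, StringType, TimestampType, DoubleType
from tempo.tsschema import TSSchema, SimpleTimestampIndex

schema = StructType([
    StructField("event_ts", TimestampType()),
    StructField("symbol", StringType()),
    StructField("trade_pr", DoubleType()),
])

ts_schema = TSSchema.fromDFSchema(schema, ts_col="event_ts", series_ids=["symbol"])
assert isinstance(ts_schema.ts_idx, SimpleTimestampIndex)

order_expr = ts_schema.ts_idx.orderByExpr()       # orders on the raw timestamp column
range_expr = ts_schema.ts_idx.rangeOrderByExpr()  # casts to double (seconds since epoch)
```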
diff --git a/python/tests/base.py b/python/tests/base.py
index 550515c6..34fd6fe1 100644
--- a/python/tests/base.py
+++ b/python/tests/base.py
@@ -134,17 +134,9 @@ def buildTestDF(self, schema, data, ts_cols=["event_ts"]):
# build dataframe
df = self.spark.createDataFrame(data, schema)
- # check if ts_col follows standard timestamp format, then check if timestamp has micro/nanoseconds
+ # convert timestamp fields to timestamp type
for tsc in ts_cols:
- ts_value = str(df.select(ts_cols).limit(1).collect()[0][0])
- ts_pattern = "^\d{4}-\d{2}-\d{2}| \d{2}:\d{2}:\d{2}\.\d*$"
- decimal_pattern = "[.]\d+"
- if re.match(ts_pattern, str(ts_value)) is not None:
- if (
- re.search(decimal_pattern, ts_value) is None
- or len(re.search(decimal_pattern, ts_value)[0]) <= 4
- ):
- df = df.withColumn(tsc, F.to_timestamp(F.col(tsc)))
+ df = df.withColumn(tsc, F.to_timestamp(F.col(tsc)))
return df
#
diff --git a/python/tests/interpol_tests.py b/python/tests/interpol_tests.py
index 91f4dd68..578409a0 100644
--- a/python/tests/interpol_tests.py
+++ b/python/tests/interpol_tests.py
@@ -26,7 +26,7 @@ def test_fill_validation(self):
try:
self.interpolate_helper.interpolate(
tsdf=input_tsdf,
- partition_cols=["partition_a", "partition_b"],
+ series_ids=["partition_a", "partition_b"],
target_cols=["value_a", "value_b"],
freq="30 seconds",
ts_col="event_ts",
@@ -49,7 +49,7 @@ def test_target_column_validation(self):
try:
self.interpolate_helper.interpolate(
tsdf=input_tsdf,
- partition_cols=["partition_a", "partition_b"],
+ series_ids=["partition_a", "partition_b"],
target_cols=["partition_a", "value_b"],
freq="30 seconds",
ts_col="event_ts",
@@ -72,7 +72,7 @@ def test_partition_column_validation(self):
try:
self.interpolate_helper.interpolate(
tsdf=input_tsdf,
- partition_cols=["partition_c", "partition_b"],
+ series_ids=["partition_c", "partition_b"],
target_cols=["value_a", "value_b"],
freq="30 seconds",
ts_col="event_ts",
@@ -95,7 +95,7 @@ def test_ts_column_validation(self):
try:
self.interpolate_helper.interpolate(
tsdf=input_tsdf,
- partition_cols=["partition_a", "partition_b"],
+ series_ids=["partition_a", "partition_b"],
target_cols=["value_a", "value_b"],
freq="30 seconds",
ts_col="value_a",
@@ -123,7 +123,7 @@ def test_zero_fill_interpolation(self):
# interpolate
actual_df: DataFrame = self.interpolate_helper.interpolate(
tsdf=simple_input_tsdf,
- partition_cols=["partition_a", "partition_b"],
+ series_ids=["partition_a", "partition_b"],
target_cols=["value_a", "value_b"],
freq="30 seconds",
ts_col="event_ts",
@@ -149,7 +149,7 @@ def test_null_fill_interpolation(self):
# interpolate
actual_df: DataFrame = self.interpolate_helper.interpolate(
tsdf=simple_input_tsdf,
- partition_cols=["partition_a", "partition_b"],
+ series_ids=["partition_a", "partition_b"],
target_cols=["value_a", "value_b"],
freq="30 seconds",
ts_col="event_ts",
@@ -176,7 +176,7 @@ def test_back_fill_interpolation(self):
# interpolate
actual_df: DataFrame = self.interpolate_helper.interpolate(
tsdf=simple_input_tsdf,
- partition_cols=["partition_a", "partition_b"],
+ series_ids=["partition_a", "partition_b"],
target_cols=["value_a", "value_b"],
freq="30 seconds",
ts_col="event_ts",
@@ -203,7 +203,7 @@ def test_forward_fill_interpolation(self):
# interpolate
actual_df: DataFrame = self.interpolate_helper.interpolate(
tsdf=simple_input_tsdf,
- partition_cols=["partition_a", "partition_b"],
+ series_ids=["partition_a", "partition_b"],
target_cols=["value_a", "value_b"],
freq="30 seconds",
ts_col="event_ts",
@@ -230,7 +230,7 @@ def test_linear_fill_interpolation(self):
# interpolate
actual_df: DataFrame = self.interpolate_helper.interpolate(
tsdf=simple_input_tsdf,
- partition_cols=["partition_a", "partition_b"],
+ series_ids=["partition_a", "partition_b"],
target_cols=["value_a", "value_b"],
freq="30 seconds",
ts_col="event_ts",
@@ -255,7 +255,7 @@ def test_different_freq_abbreviations(self):
# interpolate
actual_df: DataFrame = self.interpolate_helper.interpolate(
tsdf=simple_input_tsdf,
- partition_cols=["partition_a", "partition_b"],
+ series_ids=["partition_a", "partition_b"],
target_cols=["value_a", "value_b"],
freq="30 sec",
ts_col="event_ts",
@@ -282,7 +282,7 @@ def test_show_interpolated(self):
# interpolate
actual_df: DataFrame = self.interpolate_helper.interpolate(
tsdf=simple_input_tsdf,
- partition_cols=["partition_a", "partition_b"],
+ series_ids=["partition_a", "partition_b"],
target_cols=["value_a", "value_b"],
freq="30 seconds",
ts_col="event_ts",
@@ -322,14 +322,14 @@ def test_interpolation_using_custom_params(self):
input_tsdf = TSDF(
simple_input_tsdf.df.withColumnRenamed("event_ts", "other_ts_col"),
- partition_cols=["partition_a", "partition_b"],
ts_col="other_ts_col",
+ series_ids=["partition_a", "partition_b"]
)
actual_df: DataFrame = input_tsdf.interpolate(
ts_col="other_ts_col",
show_interpolated=True,
- partition_cols=["partition_a", "partition_b"],
+ series_ids=["partition_a", "partition_b"],
target_cols=["value_a"],
freq="30 seconds",
func="mean",
@@ -347,7 +347,7 @@ def test_tsdf_constructor_params_are_updated(self):
actual_tsdf: TSDF = simple_input_tsdf.interpolate(
ts_col="event_ts",
show_interpolated=True,
- partition_cols=["partition_b"],
+ series_ids=["partition_b"],
target_cols=["value_a"],
freq="30 seconds",
func="mean",
@@ -355,7 +355,7 @@ def test_tsdf_constructor_params_are_updated(self):
)
self.assertEqual(actual_tsdf.ts_col, "event_ts")
- self.assertEqual(actual_tsdf.partitionCols, ["partition_b"])
+ self.assertEqual(actual_tsdf.series_ids, ["partition_b"])
def test_interpolation_on_sampled_data(self):
"""Verify interpolation can be chained with resample within the TSDF class"""
diff --git a/python/tests/unit_test_data/tsdf_tests.json b/python/tests/unit_test_data/tsdf_tests.json
index 3b522fcd..95964232 100644
--- a/python/tests/unit_test_data/tsdf_tests.json
+++ b/python/tests/unit_test_data/tsdf_tests.json
@@ -831,54 +831,20 @@
"init": {
"schema": "symbol string, date string, event_ts string, trade_pr float, trade_pr_2 float",
"ts_col": "event_ts",
- "series_ids": [
- "symbol"
- ],
+ "series_ids": ["symbol"],
"data": [
- [
- "S1",
- "SAME_DT",
- "2020-08-01 00:00:10.12345",
- 349.21,
- 10.0
- ],
- [
- "S1",
- "SAME_DT",
- "2020-08-01 00:00:10.123",
- 340.21,
- 9.0
- ],
- [
- "S1",
- "SAME_DT",
- "2020-08-01 00:00:10.124",
- 353.32,
- 8.0
- ]
+ ["S1", "SAME_DT", "2020-08-01 00:00:10.12345", 349.21, 10.0],
+ ["S1", "SAME_DT", "2020-08-01 00:00:10.123", 340.21, 9.0],
+ ["S1", "SAME_DT", "2020-08-01 00:00:10.124", 353.32, 8.0]
]
},
"expectedms": {
"schema": "symbol string, event_ts string, date double, trade_pr double, trade_pr_2 double",
"ts_col": "event_ts",
- "series_ids": [
- "symbol"
- ],
+ "series_ids": ["symbol"],
"data": [
- [
- "S1",
- "2020-08-01 00:00:10.123",
- null,
- 344.71,
- 9.5
- ],
- [
- "S1",
- "2020-08-01 00:00:10.124",
- null,
- 353.32,
- 8.0
- ]
+ ["S1", "2020-08-01 00:00:10.123", null, 344.71, 9.5],
+ ["S1", "2020-08-01 00:00:10.124", null, 353.32, 8.0]
]
}
},
From 0dd3b43a75a6e2982e3e016ec885f4dc23378264 Mon Sep 17 00:00:00 2001
From: Tristan Nixon
Date: Wed, 31 Aug 2022 15:06:18 -0700
Subject: [PATCH 05/11] all as_of tests passing but 1 will need bigger
refactoring...
---
python/tempo/tsdf.py | 54 ++++++++++++++++---
python/tempo/tsschema.py | 17 +++++-
python/tests/as_of_join_tests.py | 6 +++
python/tests/base.py | 5 +-
.../unit_test_data/as_of_join_tests.json | 1 +
5 files changed, 73 insertions(+), 10 deletions(-)
diff --git a/python/tempo/tsdf.py b/python/tempo/tsdf.py
index d792142c..97b348b9 100644
--- a/python/tempo/tsdf.py
+++ b/python/tempo/tsdf.py
@@ -4,6 +4,7 @@
import operator
from functools import reduce
from typing import List, Union, Callable, Collection, Set
+from copy import deepcopy
import numpy as np
import pyspark.sql.functions as f
@@ -18,7 +19,7 @@
import tempo.io as tio
import tempo.resample as rs
from tempo.interpol import Interpolation
-from tempo.tsschema import TSIndex, TSSchema
+from tempo.tsschema import TSIndex, TSSchema, SubSequenceTSIndex
from tempo.utils import (
ENV_CAN_RENDER_HTML,
IS_DATABRICKS,
@@ -53,11 +54,41 @@ def __init__(
self.ts_schema.validate(df.schema)
def __withTransformedDF(self, new_df: DataFrame) -> "TSDF":
- return TSDF(new_df, ts_schema=self.ts_schema, validate_schema=False)
+ """
+ This helper function will create a new :class:`TSDF` using the current schema, but a new / transformed :class:`DataFrame`
+
+ :param new_df: the new / transformed :class:`DataFrame` to wrap
+
+ :return: a new TSDF object with the transformed DataFrame
+ """
+ return TSDF(new_df, ts_schema=deepcopy(self.ts_schema), validate_schema=False)
+
+ @classmethod
+ def __makeStructFromCols(cls, df: DataFrame, struct_col_name: str, cols_to_move: List[str]) -> DataFrame:
+ """
+ Transform a :class:`DataFrame` by moving certain columns into a struct
+
+ :param df: the :class:`DataFrame` to transform
+ :param struct_col_name: name of the struct column to create
+ :param cols_to_move: names of the columns to move into the struct
+
+ :return: the transformed :class:`DataFrame`
+ """
+ return df.withColumn(struct_col_name, f.struct(cols_to_move)).drop(*cols_to_move)
+
+ __DEFAULT_TS_IDX_COL = "ts_idx"
@classmethod
def fromSubsequenceCol(cls, df: DataFrame, ts_col: str, subsequence_col: str, series_ids: Collection[str] = None) -> "TSDF":
- pass
+ # construct a struct with the ts_col and subsequence_col
+ struct_col_name = cls.__DEFAULT_TS_IDX_COL
+ with_subseq_struct_df = cls.__makeStructFromCols(df, struct_col_name, [ts_col, subsequence_col])
+ # construct an appropriate TSIndex
+ subseq_struct = with_subseq_struct_df.schema[struct_col_name]
+ subseq_idx = SubSequenceTSIndex(subseq_struct, ts_col, subsequence_col)
+ # construct & return the TSDF with appropriate schema
+ return TSDF(with_subseq_struct_df, ts_schema=TSSchema(subseq_idx, series_ids))
+
@classmethod
def fromTimestampString(cls, df: DataFrame, ts_col: str, series_ids: Collection[str] = None, ts_fmt: str = "YYYY-MM-DDThh:mm:ss[.SSSSSS]") -> "TSDF":
@@ -73,7 +104,7 @@ def ts_index(self) -> "TSIndex":
@property
def ts_col(self) -> str:
- return self.ts_index.name
+ return self.ts_index.ts_col
@property
def series_ids(self) -> List[str]:
@@ -180,8 +211,9 @@ def __checkPartitionCols(self, tsdf_right):
)
def __validateTsColMatch(self, right_tsdf):
+ # TODO - can simplify this to get types from schema object
left_ts_datatype = self.df.select(self.ts_col).dtypes[0][1]
- right_ts_datatype = right_tsdf.df.select(self.ts_col).dtypes[0][1]
+ right_ts_datatype = right_tsdf.df.select(right_tsdf.ts_col).dtypes[0][1]
if left_ts_datatype != right_ts_datatype:
raise ValueError(
"left and right dataframe timestamp index columns should have same type"
@@ -243,7 +275,7 @@ def __getLastRightRow(
since it is no longer used in subsequent methods.
"""
ptntl_sort_keys = [self.ts_col, "rec_ind", sequence_col]
- sort_keys = [f.col(col_name) for col_name in ptntl_sort_keys if col_name != ""]
+ sort_keys = [f.col(col_name) for col_name in ptntl_sort_keys if col_name]
window_spec = (
Window.partitionBy(self.series_ids)
@@ -840,10 +872,13 @@ def asofJoin(
# perform asof join.
if tsPartitionVal is None:
+ seq_col = None
+ if isinstance(combined_df.ts_index, SubSequenceTSIndex):
+ seq_col = combined_df.ts_index.sub_seq_col
asofDF = combined_df.__getLastRightRow(
left_tsdf.ts_col,
right_columns,
- right_tsdf.sequence_col,
+ seq_col,
tsPartitionVal,
skipNulls,
suppress_null_warning,
@@ -852,10 +887,13 @@ def asofJoin(
tsPartitionDF = combined_df.__getTimePartitions(
tsPartitionVal, fraction=fraction
)
+ seq_col = None
+ if isinstance(tsPartitionDF.ts_index, SubSequenceTSIndex):
+ seq_col = tsPartitionDF.ts_index.sub_seq_col
asofDF = tsPartitionDF.__getLastRightRow(
left_tsdf.ts_col,
right_columns,
- right_tsdf.sequence_col,
+ seq_col,
tsPartitionVal,
skipNulls,
suppress_null_warning,
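`TSDF.fromSubsequenceCol` replaces the old `sequence_col` constructor argument: the timestamp and its tie-breaking sequence column are packed into a struct and indexed by a `SubSequenceTSIndex`. A sketch of the new entry point, assuming a DataFrame `quotes_df` with `event_ts`, `trade_seq`, and `symbol` columns (the column names here are illustrative only):

```python
from tempo.tsdf import TSDF

# previously: TSDF(quotes_df, ts_col="event_ts", partition_cols=["symbol"], sequence_col="trade_seq")
quotes_tsdf = TSDF.fromSubsequenceCol(quotes_df, "event_ts", "trade_seq", series_ids=["symbol"])

# the packed struct becomes the index; its components stay addressable by path
quotes_tsdf.ts_index.sub_seq_col  # e.g. "ts_idx.trade_seq"
```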
diff --git a/python/tempo/tsschema.py b/python/tempo/tsschema.py
index 7409f3cd..4fa1f091 100644
--- a/python/tempo/tsschema.py
+++ b/python/tempo/tsschema.py
@@ -17,11 +17,18 @@ class TSIndex(ABC):
@property
@abstractmethod
- def name(self):
+ def name(self) -> str:
"""
:return: the column name of the timeseries index
"""
+ @property
+ @abstractmethod
+ def ts_col(self) -> str:
+ """
+ :return: the name of the primary timeseries column (may or may not be the same as the name)
+ """
+
def _reverseOrNot(self, expr: Union[Column, List[Column]], reverse: bool) -> Union[Column, List[Column]]:
if not reverse:
return expr # just return the expression as-is if we're not reversing
@@ -69,6 +76,10 @@ def __init__(self, ts_col: StructField) -> None:
def name(self):
return self.__name
+ @property
+ def ts_col(self) -> str:
+ return self.name
+
def orderByExpr(self, reverse: bool = False) -> Column:
expr = Fn.col(self.name)
return self._reverseOrNot(expr, reverse)
@@ -150,6 +161,10 @@ def __init__(self, composite_ts_idx: StructField, primary_ts_col: str) -> None:
def name(self) -> str:
return self.ts_idx
+ @property
+ def ts_col(self) -> str:
+ return self.primary_ts_col
+
@property
def primary_ts_col(self) -> str:
return self.component(self.primary_ts_idx.name)
diff --git a/python/tests/as_of_join_tests.py b/python/tests/as_of_join_tests.py
index df643518..8e84f237 100644
--- a/python/tests/as_of_join_tests.py
+++ b/python/tests/as_of_join_tests.py
@@ -102,6 +102,12 @@ def test_asof_join_nanos(self):
tsdf_right, left_prefix="left", right_prefix="right"
).df
+ print("joined_df:")
+ joined_df.printSchema()
+
+ print("defExpected:")
+ dfExpected.printSchema()
+
# compare
self.assertDataFramesEqual(joined_df, dfExpected)
diff --git a/python/tests/base.py b/python/tests/base.py
index 34fd6fe1..be5992af 100644
--- a/python/tests/base.py
+++ b/python/tests/base.py
@@ -71,7 +71,10 @@ def get_data_as_sdf(self, name: str, convert_ts_col=True):
def get_data_as_tsdf(self, name: str, convert_ts_col=True):
df = self.get_data_as_sdf(name, convert_ts_col)
td = self.test_data[name]
- tsdf = TSDF(df, ts_col=td["ts_col"], series_ids=td.get("series_ids", None))
+ if "sequence_col" in td:
+ tsdf = TSDF.fromSubsequenceCol(df, td["ts_col"], td["sequence_col"], td.get("series_ids", None))
+ else:
+ tsdf = TSDF(df, ts_col=td["ts_col"], series_ids=td.get("series_ids", None))
return tsdf
TEST_DATA_FOLDER = "unit_test_data"
diff --git a/python/tests/unit_test_data/as_of_join_tests.json b/python/tests/unit_test_data/as_of_join_tests.json
index 879ed220..f68be584 100644
--- a/python/tests/unit_test_data/as_of_join_tests.json
+++ b/python/tests/unit_test_data/as_of_join_tests.json
@@ -203,6 +203,7 @@
"expected": {
"schema": "symbol string, left_event_ts string, left_trade_pr float, right_event_ts string, right_ask_pr float, right_bid_pr float",
"ts_col": "left_event_ts",
+ "other_ts_cols": ["right_event_ts"],
"series_ids": ["symbol"],
"data": [
["S1", "2022-01-01 09:59:59.123456789", 349.21, null, null, null],
From c2ef72b431d7c460b8d2d6046c2875abd838becb Mon Sep 17 00:00:00 2001
From: Tristan Nixon
Date: Wed, 14 Sep 2022 10:20:56 -0700
Subject: [PATCH 06/11] checkpoint save of current progress...
---
examples/financial_services_quickstart.py | 4 +-
python/README.md | 16 +-
python/tempo/tsdf.py | 940 +++++++++++-------
python/tempo/tsschema.py | 98 +-
python/tests/as_of_join_tests.py | 18 +-
python/tests/base.py | 9 +-
.../unit_test_data/as_of_join_tests.json | 14 +-
7 files changed, 680 insertions(+), 419 deletions(-)
diff --git a/examples/financial_services_quickstart.py b/examples/financial_services_quickstart.py
index d515bd97..7cfd9064 100644
--- a/examples/financial_services_quickstart.py
+++ b/examples/financial_services_quickstart.py
@@ -125,7 +125,7 @@
# COMMAND ----------
# DBTITLE 1,AS OF Joins - Get Latest Quote Information As Of Time of Trades
-joined_df = trades_tsdf.asofJoin(quotes_tsdf, right_prefix="quote_asof").df
+joined_df = trades_tsdf.asOfJoin(quotes_tsdf, right_prefix="quote_asof").df
display(joined_df.filter(col("symbol") == 'AMH').filter(col("quote_asof_event_ts").isNotNull()))
@@ -136,7 +136,7 @@
logging.getLogger("py4j").setLevel(logging.WARNING)
logging.getLogger("tempo").setLevel(logging.WARNING)
-joined_df = trades_tsdf.asofJoin(quotes_tsdf, tsPartitionVal=30,right_prefix="quote_asof").df
+joined_df = trades_tsdf.asOfJoin(quotes_tsdf, tsPartitionVal=30, right_prefix="quote_asof").df
display(joined_df)
# COMMAND ----------
diff --git a/python/README.md b/python/README.md
index 01ab82eb..4e1a532b 100644
--- a/python/README.md
+++ b/python/README.md
@@ -91,16 +91,20 @@ fig.show()
-
```python
-from pyspark.sql.functions import *
+from pyspark.sql.functions import *
-watch_accel_df = spark.read.format("csv").option("header", "true").load("dbfs:/home/tempo/Watch_accelerometer").withColumn("event_ts", (col("Arrival_Time").cast("double")/1000).cast("timestamp")).withColumn("x", col("x").cast("double")).withColumn("y", col("y").cast("double")).withColumn("z", col("z").cast("double")).withColumn("event_ts_dbl", col("event_ts").cast("double"))
+watch_accel_df = spark.read.format("csv").option("header", "true").load(
+ "dbfs:/home/tempo/Watch_accelerometer").withColumn("event_ts", (col("Arrival_Time").cast("double") / 1000).cast(
+ "timestamp")).withColumn("x", col("x").cast("double")).withColumn("y", col("y").cast("double")).withColumn("z",
+ col("z").cast(
+ "double")).withColumn(
+ "event_ts_dbl", col("event_ts").cast("double"))
-watch_accel_tsdf = TSDF(watch_accel_df, ts_col="event_ts", series_ids = ["User"])
+watch_accel_tsdf = TSDF(watch_accel_df, ts_col="event_ts", series_ids=["User"])
# Applying AS OF join to TSDF datasets
-joined_df = watch_accel_tsdf.asofJoin(phone_accel_tsdf, right_prefix="phone_accel")
+joined_df = watch_accel_tsdf.asOfJoin(phone_accel_tsdf, right_prefix="phone_accel")
display(joined_df)
```
@@ -118,7 +122,7 @@ fraction = overlap fraction
right_prefix = prefix used for source columns when merged into fact table
```python
-joined_df = watch_accel_tsdf.asofJoin(phone_accel_tsdf, right_prefix="watch_accel", tsPartitionVal = 10, fraction = 0.1)
+joined_df = watch_accel_tsdf.asOfJoin(phone_accel_tsdf, right_prefix="watch_accel", tsPartitionVal=10, fraction=0.1)
display(joined_df)
```
diff --git a/python/tempo/tsdf.py b/python/tempo/tsdf.py
index 97b348b9..d9e370a6 100644
--- a/python/tempo/tsdf.py
+++ b/python/tempo/tsdf.py
@@ -7,7 +7,7 @@
from copy import deepcopy
import numpy as np
-import pyspark.sql.functions as f
+import pyspark.sql.functions as Fn
from IPython.core.display import HTML
from IPython.display import display as ipydisplay
from pyspark.sql import SparkSession
@@ -19,7 +19,7 @@
import tempo.io as tio
import tempo.resample as rs
from tempo.interpol import Interpolation
-from tempo.tsschema import TSIndex, TSSchema, SubSequenceTSIndex
+from tempo.tsschema import TSIndex, TSSchema, SubSequenceTSIndex, SimpleTSIndex
from tempo.utils import (
ENV_CAN_RENDER_HTML,
IS_DATABRICKS,
@@ -29,6 +29,29 @@
logger = logging.getLogger(__name__)
+class TSDFStructureChangeError(Exception):
+ """
+ Error raised when a user attempts an operation that would alter the structure of a TSDF in a destructive manner.
+ """
+ __MSG_TEMPLATE: str = """
+ The attempted operation ({op}) is not allowed because it would result in altering the structure of the TSDF.
+ If you really want to make this change, perform the operation on the underlying DataFrame, then re-create a new TSDF.
+ {d}"""
+
+ def __init__(self, operation: str, details: str = None) -> None:
+ super().__init__(self.__MSG_TEMPLATE.format(op=operation, d=details))
+
+
+class IncompatibleTSError(Exception):
+ """
+ Error raised when an operation is attempted between two incompatible TSDFs.
+ """
+ __MSG_TEMPLATE: str = """
+ The attempted operation ({op}) cannot be performed because the given TSDFs have incompatible structure.
+ {d}"""
+
+ def __init__(self, operation: str, details: str = None) -> None:
+ super().__init__(self.__MSG_TEMPLATE.format(op=operation, d=details))
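A rough sketch of how these new exceptions surface in practice, assuming a DataFrame `df` with an `event_ts` column and a `symbol` series ID:

```python
# Hypothetical usage; `df` is assumed to exist with columns event_ts, symbol, trade_pr
from tempo.tsdf import TSDF, TSDFStructureChangeError

tsdf = TSDF(df, ts_col="event_ts", series_ids=["symbol"])
try:
    # omits the structural columns "event_ts" and "symbol"
    tsdf.select("trade_pr")
except TSDFStructureChangeError as err:
    print(err)
```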
class TSDF:
"""
@@ -41,7 +64,7 @@ def __init__(
ts_schema: TSSchema = None,
ts_col: str = None,
series_ids: Collection[str] = None,
- validate_schema=True,
+ validate_schema=True
) -> None:
self.df = df
# construct schema if we don't already have one
@@ -53,6 +76,16 @@ def __init__(
if validate_schema:
self.ts_schema.validate(df.schema)
+ def __repr__(self) -> str:
+ return self.__str__()
+
+ def __str__(self) -> str:
+ return f"""TSDF({id(self)}):
+ TS Index: {self.ts_index}
+ Series IDs: {self.series_ids}
+ Observational Cols: {self.observational_cols}
+ DataFrame: {self.df.schema}"""
+
def __withTransformedDF(self, new_df: DataFrame) -> "TSDF":
"""
This helper function will create a new :class:`TSDF` using the current schema, but a new / transformed :class:`DataFrame`
@@ -63,6 +96,19 @@ def __withTransformedDF(self, new_df: DataFrame) -> "TSDF":
"""
return TSDF(new_df, ts_schema=deepcopy(self.ts_schema), validate_schema=False)
+ def __withStandardizedColOrder(self) -> TSDF:
+ """
+        Standardizes the column ordering as follows:
+ * series_ids,
+ * ts_index,
+ * observation columns
+
+ :return: a :class:`TSDF` with the columns reordered into "standard order" (as described above)
+ """
+ std_ordered_cols = list(self.series_ids) + [self.ts_index.name] + list(self.observational_cols)
+
+ return self.__withTransformedDF(self.df.select(std_ordered_cols))
+
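Illustrative only, with hypothetical column names, of the "standard order" described in the docstring:

```python
series_ids = ["symbol"]                        # hypothetical series ID columns
ts_index_name = "event_ts"                     # hypothetical ts index name
observational_cols = ["trade_pr", "trade_id"]  # hypothetical observation columns
std_ordered_cols = list(series_ids) + [ts_index_name] + list(observational_cols)
print(std_ordered_cols)  # ['symbol', 'event_ts', 'trade_pr', 'trade_id']
```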
@classmethod
def __makeStructFromCols(cls, df: DataFrame, struct_col_name: str, cols_to_move: List[str]) -> DataFrame:
"""
@@ -74,8 +120,9 @@ def __makeStructFromCols(cls, df: DataFrame, struct_col_name: str, cols_to_move:
:return: the transformed :class:`DataFrame`
"""
- return df.withColumn(struct_col_name, f.struct(cols_to_move)).drop(*cols_to_move)
+ return df.withColumn(struct_col_name, Fn.struct(cols_to_move)).drop(*cols_to_move)
+ # default column name for constructed timeseries index struct columns
__DEFAULT_TS_IDX_COL = "ts_idx"
@classmethod
@@ -106,46 +153,26 @@ def ts_index(self) -> "TSIndex":
def ts_col(self) -> str:
return self.ts_index.ts_col
+ @property
+ def columns(self) -> List[str]:
+ return self.df.columns
+
@property
def series_ids(self) -> List[str]:
return self.ts_schema.series_ids
@property
- def structural_cols(self) -> Set[str]:
+ def structural_cols(self) -> List[str]:
return self.ts_schema.structural_columns
@property
def observational_cols(self) -> List[str]:
- return list(self.ts_schema.find_observational_columns(self.df.schema))
+ return self.ts_schema.find_observational_columns(self.df.schema)
@property
def metric_cols(self) -> List[str]:
return self.ts_schema.find_metric_columns(self.df.schema)
- # def __init__(self, df, ts_col="event_ts", partition_cols=None, sequence_col=None):
- # """
- # Constructor
- # :param df:
- # :param ts_col:
- # :param partitionCols:
- # :sequence_col every tsdf allows for a tie-breaker secondary sort key
- # """
- # self.ts_col = self.__validated_column(df, ts_col)
- # self.partitionCols = (
- # []
- # if partition_cols is None
- # else self.__validated_columns(df, partition_cols.copy())
- # )
- #
- # self.df = df
- # self.sequence_col = "" if sequence_col is None else sequence_col
- #
- # # Add customized check for string type for the timestamp. If we see a string, we will proactively created a double version of the string timestamp for sorting purposes and rename to ts_col
- # if df.schema[ts_col].dataType == "StringType":
- # sample_ts = df.limit(1).collect()[0][0]
- # self.__validate_ts_string(sample_ts)
- # self.__add_double_ts().withColumnRenamed("double_ts", self.ts_col)
-
#
# Helper functions
#
@@ -156,14 +183,14 @@ def __add_double_ts(self):
self.df.withColumn(
"nanos",
(
- f.when(
- f.col(self.ts_col).contains("."),
- f.concat(f.lit("0."), f.split(f.col(self.ts_col), "\.")[1]),
+ Fn.when(
+ Fn.col(self.ts_col).contains("."),
+ Fn.concat(Fn.lit("0."), Fn.split(Fn.col(self.ts_col), "\.")[1]),
).otherwise(0)
).cast("double"),
)
- .withColumn("long_ts", f.col(self.ts_col).cast("timestamp").cast("long"))
- .withColumn("double_ts", f.col("long_ts") + f.col("nanos"))
+ .withColumn("long_ts", Fn.col(self.ts_col).cast("timestamp").cast("long"))
+ .withColumn("double_ts", Fn.col("long_ts") + Fn.col("nanos"))
.drop("nanos")
.drop("long_ts")
)
@@ -203,68 +230,151 @@ def __validated_columns(self, df, colnames):
self.__validated_column(df, col)
return colnames
- def __checkPartitionCols(self, tsdf_right):
+ #
+ # As-Of Join and associated helper functions
+ #
+
+ def __hasSameSeriesIDs(self, tsdf_right: TSDF):
for left_col, right_col in zip(self.series_ids, tsdf_right.series_ids):
if left_col != right_col:
raise ValueError(
- "left and right dataframe partition columns should have same name in same order"
+ "left and right dataframes must have the same series ID columns, in the same order"
)
- def __validateTsColMatch(self, right_tsdf):
- # TODO - can simplify this to get types from schema object
+ def __validateTsColMatch(self, right_tsdf: TSDF):
left_ts_datatype = self.df.select(self.ts_col).dtypes[0][1]
right_ts_datatype = right_tsdf.df.select(right_tsdf.ts_col).dtypes[0][1]
if left_ts_datatype != right_ts_datatype:
raise ValueError(
- "left and right dataframe timestamp index columns should have same type"
+ "left and right dataframes must have primary time index columns of the same type"
)
- def __addPrefixToColumns(self, col_list, prefix):
+ def __addPrefixToAllColumns(self, prefix: str, include_series_ids=False):
"""
- Add prefix to all specified columns.
+        Add a prefix to all columns (by default, the series ID columns are left unprefixed).
+
+        :param prefix: the prefix string to prepend to column names
+        :param include_series_ids: whether to also prefix the series ID columns
+        :return: a new :class:`TSDF` with the renamed columns
"""
- if prefix != "":
- prefix = prefix + "_"
- df = reduce(
- lambda df, idx: df.withColumnRenamed(
- col_list[idx], "".join([prefix, col_list[idx]])
- ),
- range(len(col_list)),
- self.df,
- )
+ # no-op if no prefix defined
+ if not prefix or prefix == "":
+ return self
- if prefix == "":
- ts_col = self.ts_col
- else:
- ts_col = "".join([prefix, self.ts_col])
+ # find the columns to prefix
+ cols_to_prefix = self.columns
+ if not include_series_ids:
+ cols_to_prefix = set(cols_to_prefix).difference(self.series_ids)
+
+ # apply a renaming to all
+ renamed_tsdf = reduce(
+ lambda tsdf, col: tsdf.withColumnRenamed( col, f"{prefix}_{col}" ),
+ cols_to_prefix,
+ self
+ ) if len(cols_to_prefix) > 0 else self
- return TSDF(df, ts_col=ts_col, series_ids=self.series_ids)
+ return renamed_tsdf
- def __addColumnsFromOtherDF(self, other_cols):
+ def __prefixedColumnMapping(self, col_list, prefix):
"""
- Add columns from some other DF as lit(None), as pre-step before union.
+ Create an old -> new column name mapping by adding a prefix to all columns in the given list
"""
- new_df = reduce(
- lambda df, idx: df.withColumn(other_cols[idx], f.lit(None)),
- range(len(other_cols)),
- self.df,
- )
- return self.__withTransformedDF(new_df)
+ # no-op if no prefix defined
+ if not prefix or prefix == "":
+ return { col : col for col in col_list }
- def __combineTSDF(self, ts_df_right, combined_ts_col):
- combined_df = self.df.unionByName(ts_df_right.df).withColumn(
- combined_ts_col, f.coalesce(self.ts_col, ts_df_right.ts_col)
- )
+ # otherwise add the prefix
+ return { col : f"{prefix}_{col}" for col in col_list }
- return TSDF(combined_df, ts_col=combined_ts_col, series_ids=self.series_ids)
+ def __renameColumns(self, col_map: dict):
+ """
+ renames columns in this TSDF based on the given mapping
+ """
+
+ renamed_tsdf = reduce(
+ lambda tsdf, colmap: tsdf.withColumnRenamed( colmap[0], colmap[1] ),
+ col_map.items(),
+ self
+ ) if len(col_map) > 0 else self
+
+ return renamed_tsdf
+
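A small sketch (hypothetical column names) of the mapping these two helpers combine to apply:

```python
prefix = "right"
col_list = ["event_ts", "bid_pr", "ask_pr"]
col_map = {col: f"{prefix}_{col}" for col in col_list} if prefix else {col: col for col in col_list}
print(col_map)
# {'event_ts': 'right_event_ts', 'bid_pr': 'right_bid_pr', 'ask_pr': 'right_ask_pr'}
```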
+ def __addMissingColumnsFrom(self, other: TSDF) -> "TSDF":
+ """
+ Add missing columns from other TSDF as lit(None), as pre-step before union.
+ """
+ missing_cols = set(other.columns).difference(self.columns)
+ new_tsdf = reduce(
+ lambda tsdf, col: tsdf.withColumn(col, Fn.lit(None)),
+ missing_cols,
+ self,
+ ) if len(missing_cols) > 0 else self
+
+ return new_tsdf
+
+ def __findCommonColumns(self, other: TSDF, include_series_ids = False) -> set[str]:
+ common_cols = set(self.columns).intersection(other.columns)
+ if include_series_ids:
+ return common_cols
+ return common_cols.difference(set(self.series_ids).union(other.series_ids))
+
+ def __combineTSDF(self,
+ right: TSDF,
+ combined_ts_col: str) -> "TSDF":
+ # add all columns missing from each DF
+ left_padded_tsdf = self.__addMissingColumnsFrom(right)
+ right_padded_tsdf = right.__addMissingColumnsFrom(self)
+
+ # next, union them together,
+ combined_df = left_padded_tsdf.df.unionByName(right_padded_tsdf.df)
+
+ # coalesce a combined ts index
+ # special-case logic if one or both of these involve a sub-sequence
+ is_left_subseq = isinstance(self.ts_index, SubSequenceTSIndex)
+ is_right_subseq = isinstance(right.ts_index, SubSequenceTSIndex)
+ if (is_left_subseq or is_right_subseq): # at least one index has a sub-sequence
+ # identify which side has the sub-sequence (or both!)
+ secondary_subseq_expr = Fn.lit(None)
+ if is_left_subseq:
+ primary_subseq_expr = self.ts_index.sub_seq_col
+ if is_right_subseq:
+ secondary_subseq_expr = right.ts_index.sub_seq_col
+ else:
+ primary_subseq_expr = right.ts_index.sub_seq_col
+ # coalesce into a new struct
+ combined_ts_field = "event_ts"
+ combined_subseq_field = "sub_seq"
+ combined_df = combined_df.withColumn(combined_ts_col,
+ Fn.struct(
+ Fn.coalesce(self.ts_index.ts_col,
+ right.ts_index.ts_col).alias(combined_ts_field),
+ Fn.coalesce(primary_subseq_expr,
+ secondary_subseq_expr).alias(combined_subseq_field)
+ ))
+ # construct new SubSequenceTSIndex to represent the combined column
+ combined_ts_struct = combined_df.schema[combined_ts_col]
+ new_ts_index = SubSequenceTSIndex( combined_ts_struct, combined_ts_field, combined_subseq_field)
+ else: # no sub-sequence index, coalesce a simple TS column
+ combined_df = combined_df.withColumn(combined_ts_col,
+ Fn.coalesce(self.ts_col,right.ts_col))
+ new_ts_index = SimpleTSIndex.fromTSCol(combined_df.schema[combined_ts_col])
+
+ # finally, put the columns into a standard order
+ # (series_ids, ts_col, left_cols, right_cols)
+ base_cols = list(self.series_ids) + [combined_ts_col]
+ left_cols = list(set(self.columns).difference(base_cols))
+ right_cols = list(set(right.columns).difference(base_cols))
+
+ # return it as a TSDF
+ new_ts_schema = TSSchema( new_ts_index, self.series_ids )
+ return TSDF( combined_df.select(base_cols + left_cols + right_cols),
+ ts_schema=new_ts_schema )
def __getLastRightRow(
self,
left_ts_col,
right_cols,
- sequence_col,
tsPartitionVal,
ignoreNulls,
suppress_null_warning,
@@ -274,36 +384,52 @@ def __getLastRightRow(
self.ts_col, which is the combined time-stamp column of both left and right dataframe, is dropped at the end
since it is no longer used in subsequent methods.
"""
- ptntl_sort_keys = [self.ts_col, "rec_ind", sequence_col]
- sort_keys = [f.col(col_name) for col_name in ptntl_sort_keys if col_name]
+ # add an indicator column where the left_ts_col might be null
+ left_ts_null_indicator_col = "rec_ind"
+ unreduced_tsdf = self.withColumn(left_ts_null_indicator_col,
+ Fn.when(Fn.col(left_ts_col).isNotNull(), 1).otherwise(-1))
+
+ # build a custom ordering expression with the indicator as *second* sort column
+ # (before any other sub-sequence cols)
+ order_by_expr = unreduced_tsdf.ts_index.orderByExpr()
+ if isinstance(order_by_expr, Column):
+ order_by_expr = [order_by_expr, Fn.col(left_ts_null_indicator_col)]
+ elif isinstance(order_by_expr, list):
+            order_by_expr = ([order_by_expr[0], Fn.col(left_ts_null_indicator_col)]
+                             + order_by_expr[1:])
+ else:
+ raise TypeError(f"Timeseries index's orderByExpr has an unknown type: {type(order_by_expr)}")
+
+ unreduced_tsdf.df.orderBy(order_by_expr).show()
+
+ # build our search window
window_spec = (
- Window.partitionBy(self.series_ids)
- .orderBy(sort_keys)
+ Window.partitionBy(list(unreduced_tsdf.series_ids))
+ .orderBy(order_by_expr)
.rowsBetween(Window.unboundedPreceding, Window.currentRow)
)
if ignoreNulls is False:
if tsPartitionVal is not None:
- raise ValueError(
- "Disabling null skipping with a partition value is not supported yet."
- )
+ raise ValueError("Disabling null skipping with a partition value is not supported yet.")
+
df = reduce(
lambda df, idx: df.withColumn(
right_cols[idx],
- f.last(
- f.when(
- f.col("rec_ind") == -1, f.struct(right_cols[idx])
+ Fn.last(
+ Fn.when(
+ Fn.col(left_ts_null_indicator_col) == -1, Fn.struct(right_cols[idx])
).otherwise(None),
True, # ignore nulls because it indicates rows from the left side
).over(window_spec),
),
range(len(right_cols)),
- self.df,
+ unreduced_tsdf.df,
)
df = reduce(
lambda df, idx: df.withColumn(
- right_cols[idx], f.col(right_cols[idx])[right_cols[idx]]
+ right_cols[idx], Fn.col(right_cols[idx])[right_cols[idx]]
),
range(len(right_cols)),
df,
@@ -313,27 +439,27 @@ def __getLastRightRow(
df = reduce(
lambda df, idx: df.withColumn(
right_cols[idx],
- f.last(right_cols[idx], ignoreNulls).over(window_spec),
+ Fn.last(right_cols[idx], ignoreNulls).over(window_spec),
),
range(len(right_cols)),
- self.df,
+ unreduced_tsdf.df,
)
else:
df = reduce(
lambda df, idx: df.withColumn(
right_cols[idx],
- f.last(right_cols[idx], ignoreNulls).over(window_spec),
+ Fn.last(right_cols[idx], ignoreNulls).over(window_spec),
).withColumn(
"non_null_ct" + right_cols[idx],
- f.count(right_cols[idx]).over(window_spec),
+ Fn.count(right_cols[idx]).over(window_spec),
),
range(len(right_cols)),
- self.df,
+ unreduced_tsdf.df,
)
- df = (df.filter(f.col(left_ts_col).isNotNull()).drop(self.ts_col)).drop(
- "rec_ind"
- )
+ df = (df.filter(Fn.col(left_ts_col).isNotNull())
+ .drop(unreduced_tsdf.ts_col)
+ .drop(left_ts_null_indicator_col))
# remove the null_ct stats used to record missing values in partitioned as of join
if tsPartitionVal is not None:
@@ -356,7 +482,7 @@ def __getLastRightRow(
)
df = df.drop(column)
- return TSDF(df, ts_col=left_ts_col, series_ids=self.series_ids)
+ return TSDF(df, ts_col=left_ts_col, series_ids=self.series_ids).__withStandardizedColOrder()
def __getTimePartitions(self, tsPartitionVal, fraction=0.1):
"""
@@ -372,26 +498,26 @@ def __getTimePartitions(self, tsPartitionVal, fraction=0.1):
"""
partition_df = (
self.df.withColumn(
- "ts_col_double", f.col(self.ts_col).cast("double")
+ "ts_col_double", Fn.col(self.ts_col).cast("double")
) # double is preferred over unix_timestamp
.withColumn(
"ts_partition",
- f.lit(tsPartitionVal)
- * (f.col("ts_col_double") / f.lit(tsPartitionVal)).cast("integer"),
+ Fn.lit(tsPartitionVal)
+ * (Fn.col("ts_col_double") / Fn.lit(tsPartitionVal)).cast("integer"),
)
.withColumn(
"partition_remainder",
- (f.col("ts_col_double") - f.col("ts_partition"))
- / f.lit(tsPartitionVal),
+ (Fn.col("ts_col_double") - Fn.col("ts_partition"))
+ / Fn.lit(tsPartitionVal),
)
- .withColumn("is_original", f.lit(1))
+ .withColumn("is_original", Fn.lit(1))
).cache() # cache it because it's used twice.
# add [1 - fraction] of previous time partition to the next partition.
remainder_df = (
- partition_df.filter(f.col("partition_remainder") >= f.lit(1 - fraction))
- .withColumn("ts_partition", f.col("ts_partition") + f.lit(tsPartitionVal))
- .withColumn("is_original", f.lit(0))
+ partition_df.filter(Fn.col("partition_remainder") >= Fn.lit(1 - fraction))
+ .withColumn("ts_partition", Fn.col("ts_partition") + Fn.lit(tsPartitionVal))
+ .withColumn("is_original", Fn.lit(0))
)
df = partition_df.union(remainder_df).drop(
@@ -399,6 +525,207 @@ def __getTimePartitions(self, tsPartitionVal, fraction=0.1):
)
return TSDF(df, ts_col=self.ts_col, series_ids=self.series_ids + ["ts_partition"])
+
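A worked example of the bracketing arithmetic used above (values invented):

```python
tsPartitionVal = 30   # bracket size, in the units of the timestamp cast to double
fraction = 0.1        # overlap fraction
ts_col_double = 89.5  # a row's timestamp as a double

ts_partition = tsPartitionVal * int(ts_col_double / tsPartitionVal)    # 60
partition_remainder = (ts_col_double - ts_partition) / tsPartitionVal  # ~0.983
copy_to_next_bracket = partition_remainder >= (1 - fraction)           # True
print(ts_partition, round(partition_remainder, 3), copy_to_next_bracket)
```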
+ def __getBytesFromPlan(self, df: DataFrame, spark: SparkSession):
+ """
+ Internal helper function to obtain how many bytes in memory the Spark data frame is likely to take up. This is an upper bound and is obtained from the plan details in Spark
+
+ Parameters
+        :param df - input Spark data frame - the AS OF join has 2 data frames; this will be called for each
+        :param spark - Spark session used to query the view created from the data frame
+        """
+
+ df.createOrReplaceTempView("view")
+ plan = spark.sql("explain cost select * from view").collect()[0][0]
+
+ import re
+
+ result = (
+ re.search(r"sizeInBytes=.*(['\)])", plan, re.MULTILINE)
+ .group(0)
+ .replace(")", "")
+ )
+ size = result.split("=")[1].split(" ")[0]
+ units = result.split("=")[1].split(" ")[1]
+
+        # convert the reported size to bytes for the threshold check
+ if units == "GiB":
+ bytes = float(size) * 1024 * 1024 * 1024
+ elif units == "MiB":
+ bytes = float(size) * 1024 * 1024
+ elif units == "KiB":
+ bytes = float(size) * 1024
+ else:
+ bytes = float(size)
+
+ return bytes
+
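A sketch of the plan-size parsing on a made-up plan fragment (real plans come from `explain cost`):

```python
import re

plan = "Relation csv, Statistics(sizeInBytes=24.5 MiB)"  # hypothetical fragment
result = (
    re.search(r"sizeInBytes=.*(['\)])", plan, re.MULTILINE)
    .group(0)
    .replace(")", "")
)
size, units = result.split("=")[1].split(" ")[:2]
multiplier = {"GiB": 1024**3, "MiB": 1024**2, "KiB": 1024}.get(units, 1)
print(float(size) * multiplier)  # 25690112.0 bytes
```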
+ def __broadcastAsOfJoin(self,
+ right: TSDF,
+ left_prefix: str,
+ right_prefix: str) -> TSDF:
+
+ # prefix all columns that share common names, except for series IDs
+ common_non_series_cols = self.__findCommonColumns(right)
+        left_col_map = self.__prefixedColumnMapping(common_non_series_cols, left_prefix)
+        right_col_map = right.__prefixedColumnMapping(common_non_series_cols, right_prefix)
+        left_prefixed_tsdf = self.__renameColumns(left_col_map)
+        right_prefixed_tsdf = right.__renameColumns(right_col_map)
+
+ # build an "upper bound" for the join on the right-hand ts column
+ right_ts_col = right_prefixed_tsdf.ts_col
+ upper_bound_ts_col = "upper_bound_"+ right_ts_col
+ max_ts = "9999-12-31"
+ w = right_prefixed_tsdf.__baseWindow()
+ right_w_upper_bound = right_prefixed_tsdf.withColumn(upper_bound_ts_col,
+ Fn.coalesce(
+ Fn.lead(right_ts_col).over(w),
+ Fn.lit(max_ts).cast("timestamp")))
+
+ # perform join
+ left_ts_col = left_prefixed_tsdf.ts_col
+ series_ids = left_prefixed_tsdf.series_ids
+ res = (
+ left_prefixed_tsdf.df
+ .join(right_w_upper_bound.df, list(series_ids))
+            .where(left_prefixed_tsdf.df[left_ts_col].between(Fn.col(right_ts_col),
+ Fn.col(upper_bound_ts_col)))
+ .drop(upper_bound_ts_col)
+ )
+
+ # return as new TSDF
+ return TSDF(res, ts_col=left_ts_col, series_ids=series_ids)
+
+ def __skewAsOfJoin(self,
+ right: TSDF,
+ left_prefix: str,
+ right_prefix: str,
+ tsPartitionVal,
+ fraction=0.1,
+ skipNulls: bool = True,
+ suppress_null_warning: bool = False) -> TSDF:
+ logger.warning(
+ "You are using the skew version of the AS OF join. "
+ "This may result in null values if there are any values outside of the maximum lookback. "
+ "For maximum efficiency, choose smaller values of maximum lookback, "
+ "trading off performance and potential blank AS OF values for sparse keys"
+ )
+ # prefix all columns except for series IDs
+ left_prefixed_tsdf = self.__addPrefixToAllColumns(left_prefix)
+ right_prefixed_tsdf = right.__addPrefixToAllColumns(right_prefix)
+
+
+ # Union both dataframes, and create a combined TS column
+ combined_ts_col = "combined_ts"
+ combined_tsdf = left_prefixed_tsdf.__combineTSDF(right_prefixed_tsdf, combined_ts_col)
+ print(f"combined tsdf: {combined_tsdf}")
+
+ # set up time partitions
+ tsPartitionDF = combined_tsdf.__getTimePartitions(tsPartitionVal,
+ fraction=fraction)
+ print(f"tsPartitionDF: {tsPartitionDF}")
+
+ # resolve correct right-hand rows
+ right_cols = list(set(right_prefixed_tsdf.columns).difference(combined_tsdf.series_ids))
+ asofDF = tsPartitionDF.__getLastRightRow(
+ left_prefixed_tsdf.ts_col,
+ right_cols,
+ tsPartitionVal,
+ skipNulls,
+ suppress_null_warning,
+ )
+ print(f"asofDF: {asofDF}")
+
+ # Get rid of overlapped data and the extra columns generated from timePartitions
+ df = ( asofDF.df.filter(Fn.col("is_original") == 1)
+ .drop("ts_partition", "is_original"))
+
+ return TSDF(df, ts_col=asofDF.ts_col, series_ids=asofDF.series_ids)
+
+ def __standardAsOfJoin(self,
+ right: TSDF,
+ left_prefix: str,
+ right_prefix: str,
+ skipNulls: bool = True,
+ suppress_null_warning: bool = False) -> TSDF:
+ # prefix all columns except for series IDs
+ left_prefixed_tsdf = self.__addPrefixToAllColumns(left_prefix)
+ right_prefixed_tsdf = right.__addPrefixToAllColumns(right_prefix)
+
+ # Union both dataframes, and create a combined TS column
+ combined_ts_col = "combined_ts"
+ combined_tsdf = left_prefixed_tsdf.__combineTSDF(right_prefixed_tsdf, combined_ts_col)
+
+ # resolve correct right-hand rows
+ right_cols = list(set(right_prefixed_tsdf.columns).difference(combined_tsdf.series_ids))
+ asofDF = combined_tsdf.__getLastRightRow(
+ left_prefixed_tsdf.ts_col,
+ right_cols,
+ None,
+ skipNulls,
+ suppress_null_warning,
+ )
+
+ return asofDF
+
+ def asOfJoin(
+ self,
+ right: TSDF,
+ left_prefix: str = None,
+ right_prefix: str = "right",
+ tsPartitionVal = None,
+ fraction: float = 0.5,
+ skipNulls: bool = True,
+ sql_join_opt: bool = False,
+ suppress_null_warning: bool = False,
+ ):
+ """
+ Performs an as-of join between two time-series. If a tsPartitionVal is specified, it will do this partitioned by
+ time brackets, which can help alleviate skew.
+
+        NOTE: both TSDFs must have the same series ID columns. When WARNING-level logging is enabled,
+        summary statistics are also collected.
+
+ Parameters
+ :param right - right-hand data frame containing columns to merge in
+ :param left_prefix - optional prefix for base data frame
+ :param right_prefix - optional prefix for right-hand data frame
+        :param tsPartitionVal - size of the time brackets used to partition each series (helps mitigate skew)
+ :param fraction - overlap fraction
+ :param skipNulls - whether to skip nulls when joining in values
+ :param sql_join_opt - if set to True, will use standard Spark SQL join if it is estimated to be efficient
+ :param suppress_null_warning - when tsPartitionVal is specified, will collect min of each column and raise warnings about null values, set to True to avoid
+ """
+
+ # Check whether partition columns have the same name in both dataframes
+ self.__hasSameSeriesIDs(right)
+
+ # validate timestamp datatypes match
+ self.__validateTsColMatch(right)
+
+ # execute the broadcast-join variation
+ # choose 30MB as the cutoff for the broadcast
+ bytes_threshold = 30 * 1024 * 1024
+ spark = SparkSession.builder.getOrCreate()
+ left_bytes = self.__getBytesFromPlan(self.df, spark)
+ right_bytes = self.__getBytesFromPlan(right.df, spark)
+ if sql_join_opt & ((left_bytes < bytes_threshold)
+ | (right_bytes < bytes_threshold)):
+ spark.conf.set("spark.databricks.optimizer.rangeJoin.binSize", "60")
+            return self.__broadcastAsOfJoin(right, left_prefix, right_prefix)
+
+ # perform as-of join.
+ if tsPartitionVal is None:
+ return self.__standardAsOfJoin(right,
+ left_prefix,
+ right_prefix,
+ skipNulls,
+ suppress_null_warning)
+ else:
+ return self.__skewAsOfJoin(right,
+ left_prefix,
+ right_prefix,
+ tsPartitionVal,
+ skipNulls=skipNulls,
+ suppress_null_warning=suppress_null_warning)
+
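A hedged usage sketch of the renamed API, mirroring the README examples (input DataFrames assumed to exist):

```python
# Hypothetical inputs: trades_df and quotes_df are Spark DataFrames with an
# event_ts column and a symbol series ID.
trades_tsdf = TSDF(trades_df, ts_col="event_ts", series_ids=["symbol"])
quotes_tsdf = TSDF(quotes_df, ts_col="event_ts", series_ids=["symbol"])

# standard as-of join
joined = trades_tsdf.asOfJoin(quotes_tsdf, left_prefix="left", right_prefix="right")

# skew-safe variant: 30-unit time brackets with 10% overlap
skew_joined = trades_tsdf.asOfJoin(quotes_tsdf, right_prefix="right",
                                   tsPartitionVal=30, fraction=0.1)
joined.df.show()
```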
#
# Slicing & Selection
#
@@ -425,9 +752,7 @@ def select(self, *cols):
if set(self.structural_cols).issubset(set(cols)):
return self.__withTransformedDF(self.df.select(*cols))
else:
- raise Exception(
- "In TSDF's select statement original ts_col, partitionCols and seq_col_stub(optional) must be present"
- )
+ raise TSDFStructureChangeError("select that does not include all structural columns")
def __slice(self, op: str, target_ts):
"""
@@ -441,7 +766,7 @@ def __slice(self, op: str, target_ts):
"""
# quote our timestamp if its a string
target_expr = f"'{target_ts}'" if isinstance(target_ts, str) else target_ts
- slice_expr = f.expr(f"{self.ts_col} {op} {target_expr}")
+ slice_expr = Fn.expr(f"{self.ts_col} {op} {target_expr}")
sliced_df = self.df.where(slice_expr)
return self.__withTransformedDF(sliced_df)
@@ -521,8 +846,8 @@ def __top_rows_per_series(self, win: WindowSpec, n: int):
"""
row_num_col = "__row_num"
prev_records_df = (
- self.df.withColumn(row_num_col, f.row_number().over(win))
- .where(f.col(row_num_col) <= f.lit(n))
+ self.df.withColumn(row_num_col, Fn.row_number().over(win))
+ .where(Fn.col(row_num_col) <= Fn.lit(n))
.drop(row_num_col)
)
return self.__withTransformedDF(prev_records_df)
@@ -626,29 +951,29 @@ def describe(self):
# extract the double version of the timestamp column to summarize
double_ts_col = self.ts_col + "_dbl"
- this_df = self.df.withColumn(double_ts_col, f.col(self.ts_col).cast("double"))
+ this_df = self.df.withColumn(double_ts_col, Fn.col(self.ts_col).cast("double"))
# summary missing value percentages
missing_vals = this_df.select(
[
(
100
- * f.count(f.when(f.col(c[0]).isNull(), c[0]))
- / f.count(f.lit(1))
+ * Fn.count(Fn.when(Fn.col(c[0]).isNull(), c[0]))
+ / Fn.count(Fn.lit(1))
).alias(c[0])
for c in this_df.dtypes
if c[1] != "timestamp"
]
- ).select(f.lit("missing_vals_pct").alias("summary"), "*")
+ ).select(Fn.lit("missing_vals_pct").alias("summary"), "*")
# describe stats
desc_stats = this_df.describe().union(missing_vals)
unique_ts = this_df.select(*self.series_ids).distinct().count()
- max_ts = this_df.select(f.max(f.col(self.ts_col)).alias("max_ts")).collect()[0][
+ max_ts = this_df.select(Fn.max(Fn.col(self.ts_col)).alias("max_ts")).collect()[0][
0
]
- min_ts = this_df.select(f.min(f.col(self.ts_col)).alias("max_ts")).collect()[0][
+ min_ts = this_df.select(Fn.min(Fn.col(self.ts_col)).alias("max_ts")).collect()[0][
0
]
gran = this_df.selectExpr(
@@ -664,22 +989,22 @@ def describe(self):
non_summary_cols = [c for c in desc_stats.columns if c != "summary"]
desc_stats = desc_stats.select(
- f.col("summary"),
- f.lit(" ").alias("unique_ts_count"),
- f.lit(" ").alias("min_ts"),
- f.lit(" ").alias("max_ts"),
- f.lit(" ").alias("granularity"),
+ Fn.col("summary"),
+ Fn.lit(" ").alias("unique_ts_count"),
+ Fn.lit(" ").alias("min_ts"),
+ Fn.lit(" ").alias("max_ts"),
+ Fn.lit(" ").alias("granularity"),
*non_summary_cols,
)
# add in single record with global summary attributes and the previously computed missing value and Spark data frame describe stats
global_smry_rec = desc_stats.limit(1).select(
- f.lit("global").alias("summary"),
- f.lit(unique_ts).alias("unique_ts_count"),
- f.lit(min_ts).alias("min_ts"),
- f.lit(max_ts).alias("max_ts"),
- f.lit(gran).alias("granularity"),
- *[f.lit(" ").alias(c) for c in non_summary_cols],
+ Fn.lit("global").alias("summary"),
+ Fn.lit(unique_ts).alias("unique_ts_count"),
+ Fn.lit(min_ts).alias("min_ts"),
+ Fn.lit(max_ts).alias("max_ts"),
+ Fn.lit(gran).alias("granularity"),
+ *[Fn.lit(" ").alias(c) for c in non_summary_cols],
)
full_smry = global_smry_rec.union(desc_stats)
@@ -694,236 +1019,83 @@ def describe(self):
return full_smry
pass
- def __getBytesFromPlan(self, df, spark):
- """
- Internal helper function to obtain how many bytes in memory the Spark data frame is likely to take up. This is an upper bound and is obtained from the plan details in Spark
-
- Parameters
- :param df - input Spark data frame - the AS OF join has 2 data frames; this will be called for each
- :param spark - Spark session which is used to query the view obtained from the Spark data frame
- """
+ #
+ # Window helper functions
+ #
- df.createOrReplaceTempView("view")
- plan = spark.sql("explain cost select * from view").collect()[0][0]
+ def __baseWindow(self, reverse=False):
+ # The index will determine the appropriate sort order
+ w = Window().orderBy(self.ts_index.orderByExpr(reverse))
- import re
+ # and partitioned by any series IDs
+ if self.series_ids:
+ w = w.partitionBy([Fn.col(sid) for sid in self.series_ids])
+ return w
- result = (
- re.search(r"sizeInBytes=.*(['\)])", plan, re.MULTILINE)
- .group(0)
- .replace(")", "")
- )
- size = result.split("=")[1].split(" ")[0]
- units = result.split("=")[1].split(" ")[1]
+ def __rowsBetweenWindow(self, rows_from, rows_to, reverse=False):
+ return self.__baseWindow(reverse=reverse).rowsBetween(rows_from, rows_to)
- # perform to MB for threshold check
- if units == "GiB":
- bytes = float(size) * 1024 * 1024 * 1024
- elif units == "MiB":
- bytes = float(size) * 1024 * 1024
- elif units == "KiB":
- bytes = float(size) * 1024
- else:
- bytes = float(size)
+ def __rangeBetweenWindow(self, range_from, range_to, reverse=False):
+ return ( self.__baseWindow(reverse=reverse)
+ .orderBy(self.ts_index.rangeOrderByExpr(reverse=reverse))
+ .rangeBetween(range_from, range_to ) )
- return bytes
+ #
+ # Core Transformations
+ #
- def asofJoin(
- self,
- right_tsdf,
- left_prefix=None,
- right_prefix="right",
- tsPartitionVal=None,
- fraction=0.5,
- skipNulls=True,
- sql_join_opt=False,
- suppress_null_warning=False,
- ):
+ def withColumn(self, colName: str, col: Column) -> "TSDF":
"""
- Performs an as-of join between two time-series. If a tsPartitionVal is specified, it will do this partitioned by
- time brackets, which can help alleviate skew.
+ Returns a new :class:`TSDF` by adding a column or replacing the existing column that has the same name.
- NOTE: partition cols have to be the same for both Dataframes. We are collecting stats when the WARNING level is
- enabled also.
-
- Parameters
- :param right_tsdf - right-hand data frame containing columns to merge in
- :param left_prefix - optional prefix for base data frame
- :param right_prefix - optional prefix for right-hand data frame
- :param tsPartitionVal - value to break up each partition into time brackets
- :param fraction - overlap fraction
- :param skipNulls - whether to skip nulls when joining in values
- :param sql_join_opt - if set to True, will use standard Spark SQL join if it is estimated to be efficient
- :param suppress_null_warning - when tsPartitionVal is specified, will collect min of each column and raise warnings about null values, set to True to avoid
+ :param colName: the name of the new column (or existing column to be replaced)
+ :param col: a :class:`Column` expression for the new column definition
"""
+ if colName in self.structural_cols:
+ raise TSDFStructureChangeError(f"withColumn on the structural column {colName}.")
+ new_df = self.df.withColumn(colName, col)
+ return self.__withTransformedDF(new_df)
- # first block of logic checks whether a standard range join will suffice
- left_df = self.df
- right_df = right_tsdf.df
-
- spark = SparkSession.builder.getOrCreate()
- left_bytes = self.__getBytesFromPlan(left_df, spark)
- right_bytes = self.__getBytesFromPlan(right_df, spark)
-
- # choose 30MB as the cutoff for the broadcast
- bytes_threshold = 30 * 1024 * 1024
- if sql_join_opt & (
- (left_bytes < bytes_threshold) | (right_bytes < bytes_threshold)
- ):
- spark.conf.set("spark.databricks.optimizer.rangeJoin.binSize", 60)
- partition_cols = right_tsdf.series_ids
- left_cols = list(set(left_df.columns).difference(set(self.series_ids)))
- right_cols = list(
- set(right_df.columns).difference(set(right_tsdf.series_ids))
- )
-
- left_prefix = (
- ""
- if ((left_prefix is None) | (left_prefix == ""))
- else left_prefix + "_"
- )
- right_prefix = (
- ""
- if ((right_prefix is None) | (right_prefix == ""))
- else right_prefix + "_"
- )
-
- w = Window.partitionBy(*partition_cols).orderBy(
- right_prefix + right_tsdf.ts_col
- )
-
- new_left_ts_col = left_prefix + self.ts_col
- new_left_cols = [
- f.col(c).alias(left_prefix + c) for c in left_cols
- ] + partition_cols
- new_right_cols = [
- f.col(c).alias(right_prefix + c) for c in right_cols
- ] + partition_cols
- quotes_df_w_lag = right_df.select(*new_right_cols).withColumn(
- "lead_" + right_tsdf.ts_col,
- f.lead(right_prefix + right_tsdf.ts_col).over(w),
- )
- left_df = left_df.select(*new_left_cols)
- res = (
- left_df.join(quotes_df_w_lag, partition_cols)
- .where(
- left_df[new_left_ts_col].between(
- f.col(right_prefix + right_tsdf.ts_col),
- f.coalesce(
- f.col("lead_" + right_tsdf.ts_col),
- f.lit("2099-01-01").cast("timestamp"),
- ),
- )
- )
- .drop("lead_" + right_tsdf.ts_col)
- )
- return TSDF(res, series_ids=self.series_ids, ts_col=new_left_ts_col)
-
- # end of block checking to see if standard Spark SQL join will work
-
- if tsPartitionVal is not None:
- logger.warning(
- "You are using the skew version of the AS OF join. This may result in null values if there are any values outside of the maximum lookback. For maximum efficiency, choose smaller values of maximum lookback, trading off performance and potential blank AS OF values for sparse keys"
- )
-
- # Check whether partition columns have same name in both dataframes
- self.__checkPartitionCols(right_tsdf)
-
- # prefix non-partition columns, to avoid duplicated columns.
- left_df = self.df
- right_df = right_tsdf.df
-
- # validate timestamp datatypes match
- self.__validateTsColMatch(right_tsdf)
-
- orig_left_col_diff = list(set(left_df.columns).difference(set(self.series_ids)))
- orig_right_col_diff = list(
- set(right_df.columns).difference(set(self.series_ids))
- )
-
- left_tsdf = (
- (self.__addPrefixToColumns([self.ts_col] + orig_left_col_diff, left_prefix))
- if left_prefix is not None
- else self
- )
- right_tsdf = right_tsdf.__addPrefixToColumns(
- [right_tsdf.ts_col] + orig_right_col_diff, right_prefix
- )
-
- left_nonpartition_cols = list(
- set(left_tsdf.df.columns).difference(set(self.series_ids))
- )
- right_nonpartition_cols = list(
- set(right_tsdf.df.columns).difference(set(self.series_ids))
- )
-
- # For both dataframes get all non-partition columns (including ts_col)
- left_columns = [left_tsdf.ts_col] + left_nonpartition_cols
- right_columns = [right_tsdf.ts_col] + right_nonpartition_cols
-
- # Union both dataframes, and create a combined TS column
- combined_ts_col = "combined_ts"
- combined_df = left_tsdf.__addColumnsFromOtherDF(right_columns).__combineTSDF(
- right_tsdf.__addColumnsFromOtherDF(left_columns), combined_ts_col
- )
- combined_df.df = combined_df.df.withColumn(
- "rec_ind", f.when(f.col(left_tsdf.ts_col).isNotNull(), 1).otherwise(-1)
- )
+ def withColumnRenamed(self, existing: str, new: str) -> "TSDF":
+ """
+ Returns a new :class:`TSDF` with the given column renamed.
- # perform asof join.
- if tsPartitionVal is None:
- seq_col = None
- if isinstance(combined_df.ts_index, SubSequenceTSIndex):
- seq_col = combined_df.ts_index.sub_seq_col
- asofDF = combined_df.__getLastRightRow(
- left_tsdf.ts_col,
- right_columns,
- seq_col,
- tsPartitionVal,
- skipNulls,
- suppress_null_warning,
- )
- else:
- tsPartitionDF = combined_df.__getTimePartitions(
- tsPartitionVal, fraction=fraction
- )
- seq_col = None
- if isinstance(tsPartitionDF.ts_index, SubSequenceTSIndex):
- seq_col = tsPartitionDF.ts_index.sub_seq_col
- asofDF = tsPartitionDF.__getLastRightRow(
- left_tsdf.ts_col,
- right_columns,
- seq_col,
- tsPartitionVal,
- skipNulls,
- suppress_null_warning,
- )
+        :param existing: name of the existing column to rename
+ :param new: new name for the column
+ """
- # Get rid of overlapped data and the extra columns generated from timePartitions
- df = asofDF.df.filter(f.col("is_original") == 1).drop(
- "ts_partition", "is_original"
- )
+ # create new TSIndex
+ new_ts_index = deepcopy(self.ts_index)
+ if existing == self.ts_index.name:
+ new_ts_index = new_ts_index.renamed(new)
- asofDF = TSDF(df, ts_col=asofDF.ts_col, series_ids=combined_df.series_ids)
+ # and for series ids
+ new_series_ids = self.series_ids
+ if existing in self.series_ids:
+            # copy the series IDs so we don't mutate the original list
+            new_series_ids = list(self.series_ids)
+ new_series_ids[new_series_ids.index(existing)] = new
- return asofDF
+ # rename the column in the underlying DF
+ new_df = self.df.withColumnRenamed(existing,new)
- def __baseWindow(self, reverse=False):
- # The index will determine the appropriate sort order
- w = Window().orderBy(self.ts_index.orderByExpr(reverse))
+ # return new TSDF
+ new_schema = TSSchema(new_ts_index, new_series_ids)
+ return TSDF(new_df, ts_schema=new_schema)
- # and partitioned by any series IDs
- if self.series_ids:
- w = w.partitionBy([f.col(sid) for sid in self.series_ids])
- return w
+ def union(self, other: TSDF) -> TSDF:
+ # union of the underlying DataFrames
+ union_df = self.df.union(other.df)
+ return self.__withTransformedDF(union_df)
- def __rowsBetweenWindow(self, rows_from, rows_to, reverse=False):
- return self.__baseWindow(reverse=reverse).rowsBetween(rows_from, rows_to)
+ def unionByName(self, other: TSDF, allowMissingColumns: bool = False) -> TSDF:
+ # union of the underlying DataFrames
+ union_df = self.df.unionByName(other.df, allowMissingColumns=allowMissingColumns)
+ return self.__withTransformedDF(union_df)
- def __rangeBetweenWindow(self, range_from, range_to, reverse=False):
- return ( self.__baseWindow(reverse=reverse)
- .orderBy(self.ts_index.rangeOrderByExpr(reverse=reverse))
- .rangeBetween(range_from, range_to ) )
+ #
+ # utility functions
+ #
def vwap(self, frequency="m", volume_col="volume", price_col="price"):
# set pre_vwap as self or enrich with the frequency
@@ -931,33 +1103,33 @@ def vwap(self, frequency="m", volume_col="volume", price_col="price"):
if frequency == "m":
pre_vwap = self.df.withColumn(
"time_group",
- f.concat(
- f.lpad(f.hour(f.col(self.ts_col)), 2, "0"),
- f.lit(":"),
- f.lpad(f.minute(f.col(self.ts_col)), 2, "0"),
+ Fn.concat(
+ Fn.lpad(Fn.hour(Fn.col(self.ts_col)), 2, "0"),
+ Fn.lit(":"),
+ Fn.lpad(Fn.minute(Fn.col(self.ts_col)), 2, "0"),
),
)
elif frequency == "H":
pre_vwap = self.df.withColumn(
- "time_group", f.concat(f.lpad(f.hour(f.col(self.ts_col)), 2, "0"))
+ "time_group", Fn.concat(Fn.lpad(Fn.hour(Fn.col(self.ts_col)), 2, "0"))
)
elif frequency == "D":
pre_vwap = self.df.withColumn(
- "time_group", f.concat(f.lpad(f.day(f.col(self.ts_col)), 2, "0"))
+ "time_group", Fn.concat(Fn.lpad(Fn.day(Fn.col(self.ts_col)), 2, "0"))
)
group_cols = ["time_group"]
if self.series_ids:
group_cols.extend(self.series_ids)
vwapped = (
- pre_vwap.withColumn("dllr_value", f.col(price_col) * f.col(volume_col))
+ pre_vwap.withColumn("dllr_value", Fn.col(price_col) * Fn.col(volume_col))
.groupby(group_cols)
.agg(
sum("dllr_value").alias("dllr_value"),
sum(volume_col).alias(volume_col),
max(price_col).alias("_".join(["max", price_col])),
)
- .withColumn("vwap", f.col("dllr_value") / f.col(volume_col))
+ .withColumn("vwap", Fn.col("dllr_value") / Fn.col(volume_col))
)
return self.__withTransformedDF(vwapped)
@@ -971,18 +1143,18 @@ def EMA(self, colName, window=30, exp_factor=0.2):
"""
emaColName = "_".join(["EMA", colName])
- df = self.df.withColumn(emaColName, f.lit(0)).orderBy(self.ts_col)
+ df = self.df.withColumn(emaColName, Fn.lit(0)).orderBy(self.ts_col)
w = self.__baseWindow()
# Generate all the lag columns:
for i in range(window):
lagColName = "_".join(["lag", colName, str(i)])
weight = exp_factor * (1 - exp_factor) ** i
- df = df.withColumn(lagColName, weight * f.lag(f.col(colName), i).over(w))
+ df = df.withColumn(lagColName, weight * Fn.lag(Fn.col(colName), i).over(w))
df = df.withColumn(
emaColName,
- f.col(emaColName)
- + f.when(f.col(lagColName).isNull(), f.lit(0)).otherwise(
- f.col(lagColName)
+ Fn.col(emaColName)
+ + Fn.when(Fn.col(lagColName).isNull(), Fn.lit(0)).otherwise(
+ Fn.col(lagColName)
),
).drop(lagColName)
# Nulls are currently removed
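A quick check of the exponential weighting used above (pure Python, no Spark required):

```python
# weight_i = exp_factor * (1 - exp_factor) ** i, as in the loop above
exp_factor = 0.2
window = 30
weights = [exp_factor * (1 - exp_factor) ** i for i in range(window)]
print(weights[:3])   # [0.2, 0.16..., 0.128...]
print(sum(weights))  # ~0.9988, approaching 1.0 as the window grows
```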
@@ -1009,17 +1181,17 @@ def withLookbackFeatures(
"""
# first, join all featureCols into a single array column
tempArrayColName = "__TempArrayCol"
- feat_array_tsdf = self.df.withColumn(tempArrayColName, f.array(featureCols))
+ feat_array_tsdf = self.df.withColumn(tempArrayColName, Fn.array(featureCols))
# construct a lookback array
lookback_win = self.__rowsBetweenWindow(-lookbackWindowSize, -1)
lookback_tsdf = feat_array_tsdf.withColumn(
- featureColName, f.collect_list(f.col(tempArrayColName)).over(lookback_win)
+ featureColName, Fn.collect_list(Fn.col(tempArrayColName)).over(lookback_win)
).drop(tempArrayColName)
# make sure only windows of exact size are allowed
if exactSize:
- return lookback_tsdf.where(f.size(featureColName) == lookbackWindowSize)
+ return lookback_tsdf.where(Fn.size(featureColName) == lookbackWindowSize)
return self.__withTransformedDF(lookback_tsdf)
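Illustrative only, with invented values, of the lookback structure this produces for featureCols=["x", "y"] and lookbackWindowSize=3:

```python
lookback_example = [
    [1.0, 10.0],  # features at t-3
    [1.1, 10.5],  # features at t-2
    [1.2, 11.0],  # features at t-1
]
print(len(lookback_example))  # 3 == lookbackWindowSize
```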
@@ -1052,16 +1224,16 @@ def withRangeStats(
selectedCols = self.df.columns
derivedCols = []
for metric in colsToSummarize:
- selectedCols.append(f.mean(metric).over(w).alias("mean_" + metric))
- selectedCols.append(f.count(metric).over(w).alias("count_" + metric))
- selectedCols.append(f.min(metric).over(w).alias("min_" + metric))
- selectedCols.append(f.max(metric).over(w).alias("max_" + metric))
- selectedCols.append(f.sum(metric).over(w).alias("sum_" + metric))
- selectedCols.append(f.stddev(metric).over(w).alias("stddev_" + metric))
+ selectedCols.append(Fn.mean(metric).over(w).alias("mean_" + metric))
+ selectedCols.append(Fn.count(metric).over(w).alias("count_" + metric))
+ selectedCols.append(Fn.min(metric).over(w).alias("min_" + metric))
+ selectedCols.append(Fn.max(metric).over(w).alias("max_" + metric))
+ selectedCols.append(Fn.sum(metric).over(w).alias("sum_" + metric))
+ selectedCols.append(Fn.stddev(metric).over(w).alias("stddev_" + metric))
derivedCols.append(
(
- (f.col(metric) - f.col("mean_" + metric))
- / f.col("stddev_" + metric)
+ (Fn.col(metric) - Fn.col("mean_" + metric))
+ / Fn.col("stddev_" + metric)
).alias("zscore_" + metric)
)
selected_df = self.df.select(*selectedCols)
@@ -1102,8 +1274,8 @@ def withGroupedStats(self, metricCols=[], freq=None):
# build window
parsed_freq = rs.checkAllowableFreq(freq)
- agg_window = f.window(
- f.col(self.ts_col),
+ agg_window = Fn.window(
+ Fn.col(self.ts_col),
"{} {}".format(parsed_freq[0], rs.freq_dict[parsed_freq[1]]),
)
@@ -1112,19 +1284,19 @@ def withGroupedStats(self, metricCols=[], freq=None):
for metric in metricCols:
selectedCols.extend(
[
- f.mean(f.col(metric)).alias("mean_" + metric),
- f.count(f.col(metric)).alias("count_" + metric),
- f.min(f.col(metric)).alias("min_" + metric),
- f.max(f.col(metric)).alias("max_" + metric),
- f.sum(f.col(metric)).alias("sum_" + metric),
- f.stddev(f.col(metric)).alias("stddev_" + metric),
+ Fn.mean(Fn.col(metric)).alias("mean_" + metric),
+ Fn.count(Fn.col(metric)).alias("count_" + metric),
+ Fn.min(Fn.col(metric)).alias("min_" + metric),
+ Fn.max(Fn.col(metric)).alias("max_" + metric),
+ Fn.sum(Fn.col(metric)).alias("sum_" + metric),
+ Fn.stddev(Fn.col(metric)).alias("stddev_" + metric),
]
)
selected_df = self.df.groupBy(self.series_ids + [agg_window]).agg(*selectedCols)
summary_df = (
selected_df.select(*selected_df.columns)
- .withColumn(self.ts_col, f.col("window").start)
+ .withColumn(self.ts_col, Fn.col("window").start)
.drop("window")
)
@@ -1290,11 +1462,11 @@ def tempo_fourier_util(pdf):
data = self.df
if self.series_ids == []:
- data = data.withColumn("dummy_group", f.lit("dummy_val"))
+ data = data.withColumn("dummy_group", Fn.lit("dummy_val"))
data = (
- data.select(f.col("dummy_group"), self.ts_col, f.col(valueCol))
- .withColumn("tdval", f.col(valueCol))
- .withColumn("tpoints", f.col(self.ts_col))
+ data.select(Fn.col("dummy_group"), self.ts_col, Fn.col(valueCol))
+ .withColumn("tdval", Fn.col(valueCol))
+ .withColumn("tpoints", Fn.col(self.ts_col))
)
return_schema = ",".join(
[f"{i[0]} {i[1]}" for i in data.dtypes]
@@ -1307,9 +1479,9 @@ def tempo_fourier_util(pdf):
else:
group_cols = self.series_ids
data = (
- data.select(*group_cols, self.ts_col, f.col(valueCol))
- .withColumn("tdval", f.col(valueCol))
- .withColumn("tpoints", f.col(self.ts_col))
+ data.select(*group_cols, self.ts_col, Fn.col(valueCol))
+ .withColumn("tdval", Fn.col(valueCol))
+ .withColumn("tpoints", Fn.col(self.ts_col))
)
return_schema = ",".join(
[f"{i[0]} {i[1]}" for i in data.dtypes]
@@ -1348,7 +1520,7 @@ def extractStateIntervals(
# https://spark.apache.org/docs/latest/sql-ref-null-semantics.html#comparison-operators-
def null_safe_equals(col1: Column, col2: Column) -> Column:
return (
- f.when(col1.isNull() & col2.isNull(), True)
+ Fn.when(col1.isNull() & col2.isNull(), True)
.when(col1.isNull() | col2.isNull(), False)
.otherwise(operator.eq(col1, col2))
)
@@ -1400,7 +1572,7 @@ def state_comparison_fn(a, b):
# Get previous timestamp to identify start time of the interval
data = data.withColumn(
"previous_ts",
- f.lag(f.col(self.ts_col), offset=1).over(w),
+ Fn.lag(Fn.col(self.ts_col), offset=1).over(w),
)
# Determine state intervals using user-provided the state comparison function
@@ -1410,31 +1582,31 @@ def state_comparison_fn(a, b):
temp_metric_compare_col = f"__{mc}_compare"
data = data.withColumn(
temp_metric_compare_col,
- state_comparison_fn(f.col(mc), f.lag(f.col(mc), 1).over(w)),
+ state_comparison_fn(Fn.col(mc), Fn.lag(Fn.col(mc), 1).over(w)),
)
temp_metric_compare_cols.append(temp_metric_compare_col)
# Remove first record which will have no state change
# and produces `null` for all state comparisons
- data = data.filter(f.col("previous_ts").isNotNull())
+ data = data.filter(Fn.col("previous_ts").isNotNull())
# Each state comparison should return True if state remained constant
data = data.withColumn(
- "state_change", f.array_contains(f.array(*temp_metric_compare_cols), False)
+ "state_change", Fn.array_contains(Fn.array(*temp_metric_compare_cols), False)
)
# Count the distinct state changes to get the unique intervals
data = data.withColumn(
"state_incrementer",
- f.sum(f.col("state_change").cast("int")).over(w),
- ).filter(~f.col("state_change"))
+ Fn.sum(Fn.col("state_change").cast("int")).over(w),
+ ).filter(~Fn.col("state_change"))
# Find the start and end timestamp of the interval
result = (
data.groupBy(*self.series_ids, "state_incrementer")
.agg(
- f.min("previous_ts").alias("start_ts"),
- f.max(self.ts_col).alias("end_ts"),
+ Fn.min("previous_ts").alias("start_ts"),
+ Fn.max(self.ts_col).alias("end_ts"),
)
.drop("state_incrementer")
)
diff --git a/python/tempo/tsschema.py b/python/tempo/tsschema.py
index 4fa1f091..a93e93f6 100644
--- a/python/tempo/tsschema.py
+++ b/python/tempo/tsschema.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
-from typing import Union, Collection, List
+from typing import Any, Union, Collection, List, Iterable
from pyspark.sql import Column
import pyspark.sql.functions as Fn
@@ -15,6 +15,25 @@ class TSIndex(ABC):
Abstract base class for all Timeseries Index types
"""
+ def __eq__(self, o: object) -> bool:
+ # must be a SimpleTSIndex
+ if not isinstance(o, TSIndex):
+ return False
+ return self.indexAttributes == o.indexAttributes
+
+ def __repr__(self) -> str:
+ return self.__str__()
+
+ def __str__(self) -> str:
+ return f"""{self.__class__.__name__}({self.indexAttributes})"""
+
+ @property
+ @abstractmethod
+ def indexAttributes(self) -> dict[str,Any]:
+ """
+ :return: key attributes of this index
+ """
+
@property
@abstractmethod
def name(self) -> str:
@@ -29,6 +48,16 @@ def ts_col(self) -> str:
:return: the name of the primary timeseries column (may or may not be the same as the name)
"""
+ @abstractmethod
+ def renamed(self, new_name: str) -> "TSIndex":
+ """
+ Renames the index
+
+ :param new_name: new name of the index
+
+        :return: this :class:`TSIndex`, with its name updated
+ """
+
def _reverseOrNot(self, expr: Union[Column, List[Column]], reverse: bool) -> Union[Column, List[Column]]:
if not reverse:
return expr # just return the expression as-is if we're not reversing
@@ -72,6 +101,11 @@ def __init__(self, ts_col: StructField) -> None:
self.__name = ts_col.name
self.dataType = ts_col.dataType
+ @property
+ def indexAttributes(self) -> dict[str, Any]:
+ return { 'name': self.name,
+ 'dataType': self.dataType }
+
@property
def name(self):
return self.__name
@@ -80,6 +114,10 @@ def name(self):
def ts_col(self) -> str:
return self.name
+ def renamed(self, new_name: str) -> "TSIndex":
+ self.__name = new_name
+ return self
+
def orderByExpr(self, reverse: bool = False) -> Column:
expr = Fn.col(self.name)
return self._reverseOrNot(expr, reverse)
@@ -152,14 +190,20 @@ class CompositeTSIndex(TSIndex, ABC):
def __init__(self, composite_ts_idx: StructField, primary_ts_col: str) -> None:
if not isinstance(composite_ts_idx.dataType, StructType):
raise TypeError(f"CompoundTSIndex must be of type StructType, but given compound_ts_idx {composite_ts_idx.name} has type {composite_ts_idx.dataType}")
- self.ts_idx: str = composite_ts_idx.name
+ self.__name: str = composite_ts_idx.name
self.struct: StructType = composite_ts_idx.dataType
# construct a simple TS index object for the primary column
self.primary_ts_idx: SimpleTSIndex = SimpleTSIndex.fromTSCol(self.struct[primary_ts_col])
+ @property
+ def indexAttributes(self) -> dict[str, Any]:
+ return { 'name': self.name,
+ 'struct': self.struct,
+ 'primary_ts_col': self.primary_ts_idx }
+
@property
def name(self) -> str:
- return self.ts_idx
+ return self.__name
@property
def ts_col(self) -> str:
@@ -169,6 +213,10 @@ def ts_col(self) -> str:
def primary_ts_col(self) -> str:
return self.component(self.primary_ts_idx.name)
+ def renamed(self, new_name: str) -> "TSIndex":
+ self.__name = new_name
+ return self
+
def component(self, component_name):
"""
Returns the full path to a component column that is within the composite index
@@ -196,6 +244,12 @@ def __init__(self, composite_ts_idx: StructField, primary_ts_col: str, sub_seq_c
# construct a simple index for the sub-sequence column
self.sub_sequence_idx = NumericIndex(self.struct[sub_seq_col])
+ @property
+ def indexAttributes(self) -> dict[str, Any]:
+ attrs = super().indexAttributes
+ attrs['sub_sequence_idx'] = self.sub_sequence_idx
+ return attrs
+
@property
def sub_seq_col(self) -> str:
return self.component(self.sub_sequence_idx.name)
@@ -219,6 +273,12 @@ def __init__(self, composite_ts_idx: StructField, src_str_col: str, parsed_col:
raise TypeError(f"Source string column must be of StringType, but given column {src_str_field.name} is of type {src_str_field.dataType}")
self.__src_str_col = src_str_col
+ @property
+ def indexAttributes(self) -> dict[str, Any]:
+ attrs = super().indexAttributes
+ attrs['src_str_col'] = self.src_str_col
+ return attrs
+
@property
def src_str_col(self):
return self.component(self.__src_str_col)
@@ -285,7 +345,27 @@ def __init__(
if series_ids:
self.series_ids = list(series_ids)
else:
- self.series_ids = None
+ self.series_ids = []
+
+ def __eq__(self, o: object) -> bool:
+ # must be of TSSchema type
+ if not isinstance(o, TSSchema):
+ return False
+ # must have same TSIndex
+ if self.ts_idx != o.ts_idx:
+ return False
+ # must have the same series IDs
+ if self.series_ids != o.series_ids:
+ return False
+ return True
+
+ def __repr__(self) -> str:
+ return self.__str__()
+
+ def __str__(self) -> str:
+ return f"""TSSchema({id(self)})
+ TSIndex: {self.ts_idx}
+ Series IDs: {self.series_ids}"""
@classmethod
def fromDFSchema(
@@ -296,22 +376,20 @@ def fromDFSchema(
return cls(ts_idx, series_ids)
@property
- def structural_columns(self) -> set[str]:
+ def structural_columns(self) -> list[str]:
"""
Structural columns are those that define the structure of the :class:`TSDF`. This includes the timeseries column,
a timeseries index (if different), any subsequence column (if present), and the series ID columns.
         :return: a list of column names corresponding to the structural columns of a :class:`TSDF`
"""
- struct_cols = {self.ts_idx.name}.union(self.series_ids)
- struct_cols.discard(None)
- return struct_cols
+ return list({self.ts_idx.name}.union(self.series_ids))
def validate(self, df_schema: StructType) -> None:
pass
- def find_observational_columns(self, df_schema: StructType) -> set[str]:
- return set(df_schema.fieldNames()) - self.structural_columns
+ def find_observational_columns(self, df_schema: StructType) -> list[str]:
+ return list(set(df_schema.fieldNames()) - set(self.structural_columns))
def find_metric_columns(self, df_schema: StructType) -> list[str]:
return [
diff --git a/python/tests/as_of_join_tests.py b/python/tests/as_of_join_tests.py
index 1ca2b77d..17ef8b32 100644
--- a/python/tests/as_of_join_tests.py
+++ b/python/tests/as_of_join_tests.py
@@ -14,10 +14,10 @@ def test_asof_join(self):
noRightPrefixdfExpected = self.get_data_as_sdf("expected_no_right_prefix")
# perform the join
- joined_df = tsdf_left.asofJoin(
+ joined_df = tsdf_left.asOfJoin(
tsdf_right, left_prefix="left", right_prefix="right"
).df
- non_prefix_joined_df = tsdf_left.asofJoin(
+ non_prefix_joined_df = tsdf_left.asOfJoin(
tsdf_right, left_prefix="left", right_prefix=""
).df
@@ -25,7 +25,7 @@ def test_asof_join(self):
self.assertDataFrameEquality(joined_df, dfExpected)
self.assertDataFrameEquality(non_prefix_joined_df, noRightPrefixdfExpected)
- spark_sql_joined_df = tsdf_left.asofJoin(
+ spark_sql_joined_df = tsdf_left.asOfJoin(
tsdf_right, left_prefix="left", right_prefix="right"
).df
self.assertDataFrameEquality(spark_sql_joined_df, dfExpected)
@@ -42,7 +42,7 @@ def test_asof_join_skip_nulls_disabled(self):
)
# perform the join with skip nulls enabled (default)
- joined_df = tsdf_left.asofJoin(
+ joined_df = tsdf_left.asOfJoin(
tsdf_right, left_prefix="left", right_prefix="right"
).df
@@ -50,7 +50,7 @@ def test_asof_join_skip_nulls_disabled(self):
self.assertDataFrameEquality(joined_df, dfExpectedSkipNulls)
# perform the join with skip nulls disabled
- joined_df = tsdf_left.asofJoin(
+ joined_df = tsdf_left.asOfJoin(
tsdf_right, left_prefix="left", right_prefix="right", skipNulls=False
).df
@@ -66,10 +66,10 @@ def test_sequence_number_sort(self):
dfExpected = self.get_data_as_sdf("expected")
# perform the join
- joined_df = tsdf_left.asofJoin(tsdf_right, right_prefix="right").df
+ joined_df = tsdf_left.asOfJoin(tsdf_right, right_prefix="right")
# joined dataframe should equal the expected dataframe
- self.assertDataFrameEquality(joined_df, dfExpected)
+ self.assertDataFrameEquality(joined_df.df, dfExpected)
def test_partitioned_asof_join(self):
"""AS-OF Join with a time-partition"""
@@ -79,7 +79,7 @@ def test_partitioned_asof_join(self):
tsdf_right = self.get_data_as_tsdf("right")
dfExpected = self.get_data_as_sdf("expected")
- joined_df = tsdf_left.asofJoin(
+ joined_df = tsdf_left.asOfJoin(
tsdf_right,
left_prefix="left",
right_prefix="right",
@@ -98,7 +98,7 @@ def test_asof_join_nanos(self):
dfExpected = self.get_data_as_sdf("expected")
# perform join
- joined_df = tsdf_left.asofJoin(
+ joined_df = tsdf_left.asOfJoin(
tsdf_right, left_prefix="left", right_prefix="right"
).df
diff --git a/python/tests/base.py b/python/tests/base.py
index ed8548ec..84b24201 100644
--- a/python/tests/base.py
+++ b/python/tests/base.py
@@ -144,7 +144,14 @@ def buildTestDF(self, schema, data, ts_cols=["event_ts"]):
         # convert timestamp fields to timestamp type
for tsc in ts_cols:
- df = df.withColumn(tsc, F.to_timestamp(F.col(tsc)))
+ # check if the column is nested in a struct or not
+ if '.' in tsc:
+ # we're changing a field nested in a struct
+ (struct, field) = tsc.split('.')
+ df = df.withColumn(struct, F.col(struct).withField(field, F.to_timestamp(tsc)))
+ else:
+ # standard column
+ df = df.withColumn(tsc, F.to_timestamp(F.col(tsc)))
return df
#
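For reference, a minimal standalone sketch of the nested-field conversion used in `buildTestDF` above (illustrative only; the column and field names are hypothetical, and `Column.withField` requires Spark 3.1+):

```python
# Illustrative sketch: casting a timestamp string nested inside a struct column.
from pyspark.sql import SparkSession
import pyspark.sql.functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()

df = spark.createDataFrame(
    [("S1", ("2020-08-01 00:00:10", 1))],
    "symbol string, ts_idx struct<event_ts: string, seq_nb: int>",
)

# equivalent of the '.'-path branch above, for a hypothetical "ts_idx.event_ts"
struct_col, field = "ts_idx.event_ts".split(".")
df = df.withColumn(
    struct_col,
    F.col(struct_col).withField(field, F.to_timestamp(F.col(f"{struct_col}.{field}"))),
)
df.printSchema()  # ts_idx.event_ts is now a timestamp field
```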
diff --git a/python/tests/unit_test_data/as_of_join_tests.json b/python/tests/unit_test_data/as_of_join_tests.json
index f68be584..ff1c0595 100644
--- a/python/tests/unit_test_data/as_of_join_tests.json
+++ b/python/tests/unit_test_data/as_of_join_tests.json
@@ -121,16 +121,16 @@
]
},
"expected": {
- "schema": "symbol string, event_ts string, trade_pr float, trade_id int, right_event_ts string, right_bid_pr float, right_ask_pr float, right_seq_nb long",
+ "schema": "symbol string, combined_ts struct, event_ts: string, trade_pr float, trade_id int, right_ts_idx struct, right_bid_pr float, right_ask_pr float",
"ts_col": "event_ts",
"series_ids": ["symbol"],
- "other_ts_cols": ["right_event_ts"],
+ "other_ts_cols": ["combined_ts.event_ts", "right_ts_idx.event_ts"],
"data": [
- ["S1", "2020-08-01 00:00:10", 349.21, 1, "2020-08-01 00:00:10", 19.11, 20.12, 1],
- ["S1", "2020-08-01 00:00:10", 350.21, 5, "2020-08-01 00:00:10", 19.11, 20.12, 1],
- ["S1", "2020-08-01 00:01:12", 351.32, 2, "2020-08-01 00:01:05", 348.10, 1000.13, 3],
- ["S1", "2020-09-01 00:02:10", 361.1, 3, "2020-09-01 00:02:01", 358.93, 365.12, 4],
- ["S1", "2020-09-01 00:19:12", 362.1, 4, "2020-09-01 00:15:01", 359.21, 365.31, 5]
+ ["S1", ["2020-08-01 00:00:10", 1], "2020-08-01 00:00:10", 349.21, 1, ["2020-08-01 00:00:10", 1], 19.11, 20.12],
+ ["S1", ["2020-08-01 00:00:10", 1], "2020-08-01 00:00:10", 350.21, 5, ["2020-08-01 00:00:10", 1], 19.11, 20.12],
+ ["S1", ["2020-08-01 00:01:12", 3], "2020-08-01 00:01:12", 351.32, 2, ["2020-08-01 00:01:05", 3], 348.10, 1000.13],
+ ["S1", ["2020-09-01 00:02:10", 4], "2020-09-01 00:02:10", 361.1, 3, ["2020-09-01 00:02:01", 4], 358.93, 365.12],
+ ["S1", ["2020-09-01 00:19:12", 5], "2020-09-01 00:19:12", 362.1, 4, ["2020-09-01 00:15:01", 5], 359.21, 365.31]
]
}
},
From ef1f4ee296158afe74ced177465a8bff064aa361 Mon Sep 17 00:00:00 2001
From: Tristan Nixon
Date: Wed, 14 Sep 2022 15:34:07 -0700
Subject: [PATCH 07/11] Revert "checkpoint save of current progress..."
This reverts commit c2ef72b431d7c460b8d2d6046c2875abd838becb.
---
examples/financial_services_quickstart.py | 4 +-
python/README.md | 16 +-
python/tempo/tsdf.py | 940 +++++++-----------
python/tempo/tsschema.py | 98 +-
python/tests/as_of_join_tests.py | 18 +-
python/tests/base.py | 9 +-
.../unit_test_data/as_of_join_tests.json | 14 +-
7 files changed, 419 insertions(+), 680 deletions(-)
diff --git a/examples/financial_services_quickstart.py b/examples/financial_services_quickstart.py
index 7cfd9064..d515bd97 100644
--- a/examples/financial_services_quickstart.py
+++ b/examples/financial_services_quickstart.py
@@ -125,7 +125,7 @@
# COMMAND ----------
# DBTITLE 1,AS OF Joins - Get Latest Quote Information As Of Time of Trades
-joined_df = trades_tsdf.asOfJoin(quotes_tsdf, right_prefix="quote_asof").df
+joined_df = trades_tsdf.asofJoin(quotes_tsdf, right_prefix="quote_asof").df
display(joined_df.filter(col("symbol") == 'AMH').filter(col("quote_asof_event_ts").isNotNull()))
@@ -136,7 +136,7 @@
logging.getLogger("py4j").setLevel(logging.WARNING)
logging.getLogger("tempo").setLevel(logging.WARNING)
-joined_df = trades_tsdf.asOfJoin(quotes_tsdf, tsPartitionVal=30, right_prefix="quote_asof").df
+joined_df = trades_tsdf.asofJoin(quotes_tsdf, tsPartitionVal=30,right_prefix="quote_asof").df
display(joined_df)
# COMMAND ----------
diff --git a/python/README.md b/python/README.md
index 4e1a532b..01ab82eb 100644
--- a/python/README.md
+++ b/python/README.md
@@ -91,20 +91,16 @@ fig.show()
+
```python
-from pyspark.sql.functions import *
+from pyspark.sql.functions import *
-watch_accel_df = spark.read.format("csv").option("header", "true").load(
- "dbfs:/home/tempo/Watch_accelerometer").withColumn("event_ts", (col("Arrival_Time").cast("double") / 1000).cast(
- "timestamp")).withColumn("x", col("x").cast("double")).withColumn("y", col("y").cast("double")).withColumn("z",
- col("z").cast(
- "double")).withColumn(
- "event_ts_dbl", col("event_ts").cast("double"))
+watch_accel_df = spark.read.format("csv").option("header", "true").load("dbfs:/home/tempo/Watch_accelerometer").withColumn("event_ts", (col("Arrival_Time").cast("double")/1000).cast("timestamp")).withColumn("x", col("x").cast("double")).withColumn("y", col("y").cast("double")).withColumn("z", col("z").cast("double")).withColumn("event_ts_dbl", col("event_ts").cast("double"))
-watch_accel_tsdf = TSDF(watch_accel_df, ts_col="event_ts", series_ids=["User"])
+watch_accel_tsdf = TSDF(watch_accel_df, ts_col="event_ts", series_ids = ["User"])
# Applying AS OF join to TSDF datasets
-joined_df = watch_accel_tsdf.asOfJoin(phone_accel_tsdf, right_prefix="phone_accel")
+joined_df = watch_accel_tsdf.asofJoin(phone_accel_tsdf, right_prefix="phone_accel")
display(joined_df)
```
@@ -122,7 +118,7 @@ fraction = overlap fraction
right_prefix = prefix used for source columns when merged into fact table
```python
-joined_df = watch_accel_tsdf.asOfJoin(phone_accel_tsdf, right_prefix="watch_accel", tsPartitionVal=10, fraction=0.1)
+joined_df = watch_accel_tsdf.asofJoin(phone_accel_tsdf, right_prefix="watch_accel", tsPartitionVal = 10, fraction = 0.1)
display(joined_df)
```
diff --git a/python/tempo/tsdf.py b/python/tempo/tsdf.py
index d9e370a6..97b348b9 100644
--- a/python/tempo/tsdf.py
+++ b/python/tempo/tsdf.py
@@ -7,7 +7,7 @@
from copy import deepcopy
import numpy as np
-import pyspark.sql.functions as Fn
+import pyspark.sql.functions as f
from IPython.core.display import HTML
from IPython.display import display as ipydisplay
from pyspark.sql import SparkSession
@@ -19,7 +19,7 @@
import tempo.io as tio
import tempo.resample as rs
from tempo.interpol import Interpolation
-from tempo.tsschema import TSIndex, TSSchema, SubSequenceTSIndex, SimpleTSIndex
+from tempo.tsschema import TSIndex, TSSchema, SubSequenceTSIndex
from tempo.utils import (
ENV_CAN_RENDER_HTML,
IS_DATABRICKS,
@@ -29,29 +29,6 @@
logger = logging.getLogger(__name__)
-class TSDFStructureChangeError(Exception):
- """
- Error raised when a user attempts an operation that would alter the structure of a TSDF in a destructive manner.
- """
- __MSG_TEMPLATE: str = """
- The attempted operation ({op}) is not allowed because it would result in altering the structure of the TSDF.
- If you really want to make this change, perform the operation on the underlying DataFrame, then re-create a new TSDF.
- {d}"""
-
- def __init__(self, operation: str, details: str = None) -> None:
- super().__init__(self.__MSG_TEMPLATE.format(op=operation, d=details))
-
-
-class IncompatibleTSError(Exception):
- """
- Error raised when an operation is attempted between two incompatible TSDFs.
- """
- __MSG_TEMPLATE: str = """
- The attempted operation ({op}) cannot be performed because the given TSDFs have incompatible structure.
- {d}"""
-
- def __init__(self, operation: str, details: str = None) -> None:
- super().__init__(self.__MSG_TEMPLATE.format(op=operation, d=details))
class TSDF:
"""
@@ -64,7 +41,7 @@ def __init__(
ts_schema: TSSchema = None,
ts_col: str = None,
series_ids: Collection[str] = None,
- validate_schema=True
+ validate_schema=True,
) -> None:
self.df = df
# construct schema if we don't already have one
@@ -76,16 +53,6 @@ def __init__(
if validate_schema:
self.ts_schema.validate(df.schema)
- def __repr__(self) -> str:
- return self.__str__()
-
- def __str__(self) -> str:
- return f"""TSDF({id(self)}):
- TS Index: {self.ts_index}
- Series IDs: {self.series_ids}
- Observational Cols: {self.observational_cols}
- DataFrame: {self.df.schema}"""
-
def __withTransformedDF(self, new_df: DataFrame) -> "TSDF":
"""
This helper function will create a new :class:`TSDF` using the current schema, but a new / transformed :class:`DataFrame`
@@ -96,19 +63,6 @@ def __withTransformedDF(self, new_df: DataFrame) -> "TSDF":
"""
return TSDF(new_df, ts_schema=deepcopy(self.ts_schema), validate_schema=False)
- def __withStandardizedColOrder(self) -> TSDF:
- """
- Standardizes the column ordering as such:
- * series_ids,
- * ts_index,
- * observation columns
-
- :return: a :class:`TSDF` with the columns reordered into "standard order" (as described above)
- """
- std_ordered_cols = list(self.series_ids) + [self.ts_index.name] + list(self.observational_cols)
-
- return self.__withTransformedDF(self.df.select(std_ordered_cols))
-
@classmethod
def __makeStructFromCols(cls, df: DataFrame, struct_col_name: str, cols_to_move: List[str]) -> DataFrame:
"""
@@ -120,9 +74,8 @@ def __makeStructFromCols(cls, df: DataFrame, struct_col_name: str, cols_to_move:
:return: the transformed :class:`DataFrame`
"""
- return df.withColumn(struct_col_name, Fn.struct(cols_to_move)).drop(*cols_to_move)
+ return df.withColumn(struct_col_name, f.struct(cols_to_move)).drop(*cols_to_move)
- # default column name for constructed timeseries index struct columns
__DEFAULT_TS_IDX_COL = "ts_idx"
@classmethod
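The `__makeStructFromCols` helper above simply packs a list of columns into a single struct column and drops the originals; a minimal sketch of that pattern with hypothetical column names:

```python
# Illustrative sketch: move a set of columns into a single struct column.
from pyspark.sql import SparkSession
import pyspark.sql.functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame(
    [("S1", "2020-08-01 00:00:10", 1, 349.21)],
    "symbol string, event_ts string, trade_seq_nb int, trade_pr double",
)

cols_to_move = ["event_ts", "trade_seq_nb"]
packed = df.withColumn("ts_idx", F.struct(*cols_to_move)).drop(*cols_to_move)
packed.printSchema()  # symbol, trade_pr, ts_idx: struct<event_ts, trade_seq_nb>
```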
@@ -153,26 +106,46 @@ def ts_index(self) -> "TSIndex":
def ts_col(self) -> str:
return self.ts_index.ts_col
- @property
- def columns(self) -> List[str]:
- return self.df.columns
-
@property
def series_ids(self) -> List[str]:
return self.ts_schema.series_ids
@property
- def structural_cols(self) -> List[str]:
+ def structural_cols(self) -> Set[str]:
return self.ts_schema.structural_columns
@property
def observational_cols(self) -> List[str]:
- return self.ts_schema.find_observational_columns(self.df.schema)
+ return list(self.ts_schema.find_observational_columns(self.df.schema))
@property
def metric_cols(self) -> List[str]:
return self.ts_schema.find_metric_columns(self.df.schema)
+ # def __init__(self, df, ts_col="event_ts", partition_cols=None, sequence_col=None):
+ # """
+ # Constructor
+ # :param df:
+ # :param ts_col:
+ # :param partitionCols:
+ # :sequence_col every tsdf allows for a tie-breaker secondary sort key
+ # """
+ # self.ts_col = self.__validated_column(df, ts_col)
+ # self.partitionCols = (
+ # []
+ # if partition_cols is None
+ # else self.__validated_columns(df, partition_cols.copy())
+ # )
+ #
+ # self.df = df
+ # self.sequence_col = "" if sequence_col is None else sequence_col
+ #
+ # # Add customized check for string type for the timestamp. If we see a string, we will proactively created a double version of the string timestamp for sorting purposes and rename to ts_col
+ # if df.schema[ts_col].dataType == "StringType":
+ # sample_ts = df.limit(1).collect()[0][0]
+ # self.__validate_ts_string(sample_ts)
+ # self.__add_double_ts().withColumnRenamed("double_ts", self.ts_col)
+
#
# Helper functions
#
@@ -183,14 +156,14 @@ def __add_double_ts(self):
self.df.withColumn(
"nanos",
(
- Fn.when(
- Fn.col(self.ts_col).contains("."),
- Fn.concat(Fn.lit("0."), Fn.split(Fn.col(self.ts_col), "\.")[1]),
+ f.when(
+ f.col(self.ts_col).contains("."),
+ f.concat(f.lit("0."), f.split(f.col(self.ts_col), "\.")[1]),
).otherwise(0)
).cast("double"),
)
- .withColumn("long_ts", Fn.col(self.ts_col).cast("timestamp").cast("long"))
- .withColumn("double_ts", Fn.col("long_ts") + Fn.col("nanos"))
+ .withColumn("long_ts", f.col(self.ts_col).cast("timestamp").cast("long"))
+ .withColumn("double_ts", f.col("long_ts") + f.col("nanos"))
.drop("nanos")
.drop("long_ts")
)
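The string-timestamp handling in `__add_double_ts` splits off the fractional seconds and adds them back onto the epoch seconds; a rough plain-Python sketch of the same arithmetic (illustrative only, assuming a UTC timestamp string):

```python
# Illustrative sketch of the __add_double_ts arithmetic in plain Python:
# whole seconds come from the parsed timestamp, sub-second digits are kept
# as a fraction and added back, yielding a sortable double.
from datetime import datetime, timezone

ts_str = "2020-08-01 00:00:10.123456789"
whole, _, frac = ts_str.partition(".")
nanos = float("0." + frac) if frac else 0.0
long_ts = datetime.strptime(whole, "%Y-%m-%d %H:%M:%S").replace(tzinfo=timezone.utc).timestamp()
double_ts = long_ts + nanos
print(double_ts)  # ~1596240010.123 (subject to float precision)
```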
@@ -230,151 +203,68 @@ def __validated_columns(self, df, colnames):
self.__validated_column(df, col)
return colnames
- #
- # As-Of Join and associated helper functions
- #
-
- def __hasSameSeriesIDs(self, tsdf_right: TSDF):
+ def __checkPartitionCols(self, tsdf_right):
for left_col, right_col in zip(self.series_ids, tsdf_right.series_ids):
if left_col != right_col:
raise ValueError(
- "left and right dataframes must have the same series ID columns, in the same order"
+ "left and right dataframe partition columns should have same name in same order"
)
- def __validateTsColMatch(self, right_tsdf: TSDF):
+ def __validateTsColMatch(self, right_tsdf):
+ # TODO - can simplify this to get types from schema object
left_ts_datatype = self.df.select(self.ts_col).dtypes[0][1]
right_ts_datatype = right_tsdf.df.select(right_tsdf.ts_col).dtypes[0][1]
if left_ts_datatype != right_ts_datatype:
raise ValueError(
- "left and right dataframes must have primary time index columns of the same type"
+ "left and right dataframe timestamp index columns should have same type"
)
- def __addPrefixToAllColumns(self, prefix: str, include_series_ids=False):
+ def __addPrefixToColumns(self, col_list, prefix):
"""
-
- :param prefix:
- :param include_series_ids:
- :return:
+ Add prefix to all specified columns.
"""
+ if prefix != "":
+ prefix = prefix + "_"
- # no-op if no prefix defined
- if not prefix or prefix == "":
- return self
-
- # find the columns to prefix
- cols_to_prefix = self.columns
- if not include_series_ids:
- cols_to_prefix = set(cols_to_prefix).difference(self.series_ids)
-
- # apply a renaming to all
- renamed_tsdf = reduce(
- lambda tsdf, col: tsdf.withColumnRenamed( col, f"{prefix}_{col}" ),
- cols_to_prefix,
- self
- ) if len(cols_to_prefix) > 0 else self
-
- return renamed_tsdf
-
- def __prefixedColumnMapping(self, col_list, prefix):
- """
- Create an old -> new column name mapping by adding a prefix to all columns in the given list
- """
+ df = reduce(
+ lambda df, idx: df.withColumnRenamed(
+ col_list[idx], "".join([prefix, col_list[idx]])
+ ),
+ range(len(col_list)),
+ self.df,
+ )
- # no-op if no prefix defined
- if not prefix or prefix == "":
- return { col : col for col in col_list }
+ if prefix == "":
+ ts_col = self.ts_col
+ else:
+ ts_col = "".join([prefix, self.ts_col])
- # otherwise add the prefix
- return { col : f"{prefix}_{col}" for col in col_list }
+ return TSDF(df, ts_col=ts_col, series_ids=self.series_ids)
- def __renameColumns(self, col_map: dict):
+ def __addColumnsFromOtherDF(self, other_cols):
"""
- renames columns in this TSDF based on the given mapping
+ Add columns from some other DF as lit(None), as a pre-step before union.
"""
+ new_df = reduce(
+ lambda df, idx: df.withColumn(other_cols[idx], f.lit(None)),
+ range(len(other_cols)),
+ self.df,
+ )
- renamed_tsdf = reduce(
- lambda tsdf, colmap: tsdf.withColumnRenamed( colmap[0], colmap[1] ),
- col_map.items(),
- self
- ) if len(col_map) > 0 else self
+ return self.__withTransformedDF(new_df)
- return renamed_tsdf
+ def __combineTSDF(self, ts_df_right, combined_ts_col):
+ combined_df = self.df.unionByName(ts_df_right.df).withColumn(
+ combined_ts_col, f.coalesce(self.ts_col, ts_df_right.ts_col)
+ )
- def __addMissingColumnsFrom(self, other: TSDF) -> "TSDF":
- """
- Add missing columns from other TSDF as lit(None), as pre-step before union.
- """
- missing_cols = set(other.columns).difference(self.columns)
- new_tsdf = reduce(
- lambda tsdf, col: tsdf.withColumn(col, Fn.lit(None)),
- missing_cols,
- self,
- ) if len(missing_cols) > 0 else self
-
- return new_tsdf
-
- def __findCommonColumns(self, other: TSDF, include_series_ids = False) -> set[str]:
- common_cols = set(self.columns).intersection(other.columns)
- if include_series_ids:
- return common_cols
- return common_cols.difference(set(self.series_ids).union(other.series_ids))
-
- def __combineTSDF(self,
- right: TSDF,
- combined_ts_col: str) -> "TSDF":
- # add all columns missing from each DF
- left_padded_tsdf = self.__addMissingColumnsFrom(right)
- right_padded_tsdf = right.__addMissingColumnsFrom(self)
-
- # next, union them together,
- combined_df = left_padded_tsdf.df.unionByName(right_padded_tsdf.df)
-
- # coalesce a combined ts index
- # special-case logic if one or both of these involve a sub-sequence
- is_left_subseq = isinstance(self.ts_index, SubSequenceTSIndex)
- is_right_subseq = isinstance(right.ts_index, SubSequenceTSIndex)
- if (is_left_subseq or is_right_subseq): # at least one index has a sub-sequence
- # identify which side has the sub-sequence (or both!)
- secondary_subseq_expr = Fn.lit(None)
- if is_left_subseq:
- primary_subseq_expr = self.ts_index.sub_seq_col
- if is_right_subseq:
- secondary_subseq_expr = right.ts_index.sub_seq_col
- else:
- primary_subseq_expr = right.ts_index.sub_seq_col
- # coalesce into a new struct
- combined_ts_field = "event_ts"
- combined_subseq_field = "sub_seq"
- combined_df = combined_df.withColumn(combined_ts_col,
- Fn.struct(
- Fn.coalesce(self.ts_index.ts_col,
- right.ts_index.ts_col).alias(combined_ts_field),
- Fn.coalesce(primary_subseq_expr,
- secondary_subseq_expr).alias(combined_subseq_field)
- ))
- # construct new SubSequenceTSIndex to represent the combined column
- combined_ts_struct = combined_df.schema[combined_ts_col]
- new_ts_index = SubSequenceTSIndex( combined_ts_struct, combined_ts_field, combined_subseq_field)
- else: # no sub-sequence index, coalesce a simple TS column
- combined_df = combined_df.withColumn(combined_ts_col,
- Fn.coalesce(self.ts_col,right.ts_col))
- new_ts_index = SimpleTSIndex.fromTSCol(combined_df.schema[combined_ts_col])
-
- # finally, put the columns into a standard order
- # (series_ids, ts_col, left_cols, right_cols)
- base_cols = list(self.series_ids) + [combined_ts_col]
- left_cols = list(set(self.columns).difference(base_cols))
- right_cols = list(set(right.columns).difference(base_cols))
-
- # return it as a TSDF
- new_ts_schema = TSSchema( new_ts_index, self.series_ids )
- return TSDF( combined_df.select(base_cols + left_cols + right_cols),
- ts_schema=new_ts_schema )
+ return TSDF(combined_df, ts_col=combined_ts_col, series_ids=self.series_ids)
def __getLastRightRow(
self,
left_ts_col,
right_cols,
+ sequence_col,
tsPartitionVal,
ignoreNulls,
suppress_null_warning,
@@ -384,52 +274,36 @@ def __getLastRightRow(
self.ts_col, which is the combined timestamp column of both the left and right dataframes, is dropped at the end
since it is no longer used in subsequent methods.
"""
+ ptntl_sort_keys = [self.ts_col, "rec_ind", sequence_col]
+ sort_keys = [f.col(col_name) for col_name in ptntl_sort_keys if col_name]
- # add an indicator column where the left_ts_col might be null
- left_ts_null_indicator_col = "rec_ind"
- unreduced_tsdf = self.withColumn(left_ts_null_indicator_col,
- Fn.when(Fn.col(left_ts_col).isNotNull(), 1).otherwise(-1))
-
- # build a custom ordering expression with the indicator as *second* sort column
- # (before any other sub-sequence cols)
- order_by_expr = unreduced_tsdf.ts_index.orderByExpr()
- if isinstance(order_by_expr, Column):
- order_by_expr = [order_by_expr, Fn.col(left_ts_null_indicator_col)]
- elif isinstance(order_by_expr, list):
- order_by_expr = [ order_by_expr[0], Fn.col(left_ts_null_indicator_col) ]
- order_by_expr.extend(order_by_expr[1:])
- else:
- raise TypeError(f"Timeseries index's orderByExpr has an unknown type: {type(order_by_expr)}")
-
- unreduced_tsdf.df.orderBy(order_by_expr).show()
-
- # build our search window
window_spec = (
- Window.partitionBy(list(unreduced_tsdf.series_ids))
- .orderBy(order_by_expr)
+ Window.partitionBy(self.series_ids)
+ .orderBy(sort_keys)
.rowsBetween(Window.unboundedPreceding, Window.currentRow)
)
if ignoreNulls is False:
if tsPartitionVal is not None:
- raise ValueError("Disabling null skipping with a partition value is not supported yet.")
-
+ raise ValueError(
+ "Disabling null skipping with a partition value is not supported yet."
+ )
df = reduce(
lambda df, idx: df.withColumn(
right_cols[idx],
- Fn.last(
- Fn.when(
- Fn.col(left_ts_null_indicator_col) == -1, Fn.struct(right_cols[idx])
+ f.last(
+ f.when(
+ f.col("rec_ind") == -1, f.struct(right_cols[idx])
).otherwise(None),
True, # ignore nulls because it indicates rows from the left side
).over(window_spec),
),
range(len(right_cols)),
- unreduced_tsdf.df,
+ self.df,
)
df = reduce(
lambda df, idx: df.withColumn(
- right_cols[idx], Fn.col(right_cols[idx])[right_cols[idx]]
+ right_cols[idx], f.col(right_cols[idx])[right_cols[idx]]
),
range(len(right_cols)),
df,
@@ -439,27 +313,27 @@ def __getLastRightRow(
df = reduce(
lambda df, idx: df.withColumn(
right_cols[idx],
- Fn.last(right_cols[idx], ignoreNulls).over(window_spec),
+ f.last(right_cols[idx], ignoreNulls).over(window_spec),
),
range(len(right_cols)),
- unreduced_tsdf.df,
+ self.df,
)
else:
df = reduce(
lambda df, idx: df.withColumn(
right_cols[idx],
- Fn.last(right_cols[idx], ignoreNulls).over(window_spec),
+ f.last(right_cols[idx], ignoreNulls).over(window_spec),
).withColumn(
"non_null_ct" + right_cols[idx],
- Fn.count(right_cols[idx]).over(window_spec),
+ f.count(right_cols[idx]).over(window_spec),
),
range(len(right_cols)),
- unreduced_tsdf.df,
+ self.df,
)
- df = (df.filter(Fn.col(left_ts_col).isNotNull())
- .drop(unreduced_tsdf.ts_col)
- .drop(left_ts_null_indicator_col))
+ df = (df.filter(f.col(left_ts_col).isNotNull()).drop(self.ts_col)).drop(
+ "rec_ind"
+ )
# remove the null_ct stats used to record missing values in partitioned as of join
if tsPartitionVal is not None:
@@ -482,7 +356,7 @@ def __getLastRightRow(
)
df = df.drop(column)
- return TSDF(df, ts_col=left_ts_col, series_ids=self.series_ids).__withStandardizedColOrder()
+ return TSDF(df, ts_col=left_ts_col, series_ids=self.series_ids)
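The core of `__getLastRightRow` is a `last(..., ignorenulls=True)` over a trailing, time-ordered window; a minimal sketch of that pattern with hypothetical data:

```python
# Illustrative sketch: carry the most recent non-null right-hand value
# forward along the combined, time-ordered stream.
from pyspark.sql import SparkSession, Window
import pyspark.sql.functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame(
    [("S1", 1, None), ("S1", 2, 19.11), ("S1", 3, None), ("S1", 4, 348.10)],
    "symbol string, ts long, right_bid_pr double",
)

w = (
    Window.partitionBy("symbol")
    .orderBy("ts")
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)
filled = df.withColumn("right_bid_pr", F.last("right_bid_pr", ignorenulls=True).over(w))
filled.show()  # the null at ts=3 is filled with 19.11 from ts=2
```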
def __getTimePartitions(self, tsPartitionVal, fraction=0.1):
"""
@@ -498,26 +372,26 @@ def __getTimePartitions(self, tsPartitionVal, fraction=0.1):
"""
partition_df = (
self.df.withColumn(
- "ts_col_double", Fn.col(self.ts_col).cast("double")
+ "ts_col_double", f.col(self.ts_col).cast("double")
) # double is preferred over unix_timestamp
.withColumn(
"ts_partition",
- Fn.lit(tsPartitionVal)
- * (Fn.col("ts_col_double") / Fn.lit(tsPartitionVal)).cast("integer"),
+ f.lit(tsPartitionVal)
+ * (f.col("ts_col_double") / f.lit(tsPartitionVal)).cast("integer"),
)
.withColumn(
"partition_remainder",
- (Fn.col("ts_col_double") - Fn.col("ts_partition"))
- / Fn.lit(tsPartitionVal),
+ (f.col("ts_col_double") - f.col("ts_partition"))
+ / f.lit(tsPartitionVal),
)
- .withColumn("is_original", Fn.lit(1))
+ .withColumn("is_original", f.lit(1))
).cache() # cache it because it's used twice.
# add [1 - fraction] of previous time partition to the next partition.
remainder_df = (
- partition_df.filter(Fn.col("partition_remainder") >= Fn.lit(1 - fraction))
- .withColumn("ts_partition", Fn.col("ts_partition") + Fn.lit(tsPartitionVal))
- .withColumn("is_original", Fn.lit(0))
+ partition_df.filter(f.col("partition_remainder") >= f.lit(1 - fraction))
+ .withColumn("ts_partition", f.col("ts_partition") + f.lit(tsPartitionVal))
+ .withColumn("is_original", f.lit(0))
)
df = partition_df.union(remainder_df).drop(
@@ -525,207 +399,6 @@ def __getTimePartitions(self, tsPartitionVal, fraction=0.1):
)
return TSDF(df, ts_col=self.ts_col, series_ids=self.series_ids + ["ts_partition"])
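The bucketing above is plain integer arithmetic on the double timestamp; a small sketch (illustrative only) showing how a row is assigned its `ts_partition` and, near a bucket boundary, duplicated into the next one:

```python
# Illustrative sketch of the __getTimePartitions bucketing in plain Python.
ts_partition_val = 30      # bucket width in seconds
fraction = 0.1             # rows in the last 10% of a bucket are also copied forward

def buckets_for(ts_double: float):
    partition = ts_partition_val * int(ts_double / ts_partition_val)
    remainder = (ts_double - partition) / ts_partition_val
    out = [(partition, 1)]                             # (ts_partition, is_original)
    if remainder >= 1 - fraction:
        out.append((partition + ts_partition_val, 0))  # overlap copy into next bucket
    return out

print(buckets_for(65.0))   # [(60, 1)]            -- middle of the 60s bucket
print(buckets_for(89.5))   # [(60, 1), (90, 0)]   -- near the boundary, copied ahead
```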
-
- def __getBytesFromPlan(self, df: DataFrame, spark: SparkSession):
- """
- Internal helper function to obtain how many bytes in memory the Spark data frame is likely to take up. This is an upper bound and is obtained from the plan details in Spark
-
- Parameters
- :param df - input Spark data frame - the AS OF join has 2 data frames; this will be called for each
- """
-
- df.createOrReplaceTempView("view")
- plan = spark.sql("explain cost select * from view").collect()[0][0]
-
- import re
-
- result = (
- re.search(r"sizeInBytes=.*(['\)])", plan, re.MULTILINE)
- .group(0)
- .replace(")", "")
- )
- size = result.split("=")[1].split(" ")[0]
- units = result.split("=")[1].split(" ")[1]
-
- # perform to MB for threshold check
- if units == "GiB":
- bytes = float(size) * 1024 * 1024 * 1024
- elif units == "MiB":
- bytes = float(size) * 1024 * 1024
- elif units == "KiB":
- bytes = float(size) * 1024
- else:
- bytes = float(size)
-
- return bytes
-
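The removed `__getBytesFromPlan` above scrapes the `sizeInBytes` estimate out of the cost-based plan text; a rough standalone sketch of that parsing (illustrative only — the sample plan string is hypothetical and the exact format varies across Spark versions):

```python
# Illustrative sketch: extract an upper-bound size estimate from an
# "EXPLAIN COST" plan string. The plan text below is a made-up example.
import re

plan = "... Statistics(sizeInBytes=13.4 MiB, rowCount=1.20E+5) ..."

match = re.search(r"sizeInBytes=([\d.]+)\s*(\w+)", plan)
size, units = float(match.group(1)), match.group(2)

multipliers = {"GiB": 1024**3, "MiB": 1024**2, "KiB": 1024, "B": 1}
approx_bytes = size * multipliers.get(units, 1)
print(approx_bytes)  # ~14,050,918 bytes for 13.4 MiB
```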
- def __broadcastAsOfJoin(self,
- right: TSDF,
- left_prefix: str,
- right_prefix: str) -> TSDF:
-
- # prefix all columns that share common names, except for series IDs
- common_non_series_cols = self.__findCommonColumns(right)
- left_prefixed_tsdf = self.__prefixedColumnMapping(common_non_series_cols, left_prefix)
- right_prefixed_tsdf = right.__prefixedColumnMapping(common_non_series_cols, right_prefix)
-
- # build an "upper bound" for the join on the right-hand ts column
- right_ts_col = right_prefixed_tsdf.ts_col
- upper_bound_ts_col = "upper_bound_"+ right_ts_col
- max_ts = "9999-12-31"
- w = right_prefixed_tsdf.__baseWindow()
- right_w_upper_bound = right_prefixed_tsdf.withColumn(upper_bound_ts_col,
- Fn.coalesce(
- Fn.lead(right_ts_col).over(w),
- Fn.lit(max_ts).cast("timestamp")))
-
- # perform join
- left_ts_col = left_prefixed_tsdf.ts_col
- series_ids = left_prefixed_tsdf.series_ids
- res = (
- left_prefixed_tsdf.df
- .join(right_w_upper_bound.df, list(series_ids))
- .where(left_prefixed_tsdf[left_ts_col].between(Fn.col(right_ts_col),
- Fn.col(upper_bound_ts_col)))
- .drop(upper_bound_ts_col)
- )
-
- # return as new TSDF
- return TSDF(res, ts_col=left_ts_col, series_ids=series_ids)
-
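The removed `__broadcastAsOfJoin` above rewrites the as-of join as a range join: every right-hand row gets an upper bound equal to the next right-hand timestamp, and left rows match whichever interval they fall into. A compact sketch of that trick with hypothetical tables:

```python
# Illustrative sketch: as-of join expressed as an interval (range) join.
from pyspark.sql import SparkSession, Window
import pyspark.sql.functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()

trades = spark.createDataFrame(
    [("S1", "2020-08-01 00:01:12", 351.32)],
    "symbol string, event_ts string, trade_pr double",
).withColumn("event_ts", F.to_timestamp("event_ts"))

quotes = spark.createDataFrame(
    [("S1", "2020-08-01 00:00:10", 19.11), ("S1", "2020-08-01 00:01:05", 348.10)],
    "symbol string, quote_ts string, bid_pr double",
).withColumn("quote_ts", F.to_timestamp("quote_ts"))

w = Window.partitionBy("symbol").orderBy("quote_ts")
quotes_bounded = quotes.withColumn(
    "upper_bound_ts",
    F.coalesce(F.lead("quote_ts").over(w), F.lit("9999-12-31").cast("timestamp")),
)

asof = trades.join(quotes_bounded, "symbol").where(
    F.col("event_ts").between(F.col("quote_ts"), F.col("upper_bound_ts"))
)
asof.show()  # the 00:01:05 quote is the as-of match for the 00:01:12 trade
```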
- def __skewAsOfJoin(self,
- right: TSDF,
- left_prefix: str,
- right_prefix: str,
- tsPartitionVal,
- fraction=0.1,
- skipNulls: bool = True,
- suppress_null_warning: bool = False) -> TSDF:
- logger.warning(
- "You are using the skew version of the AS OF join. "
- "This may result in null values if there are any values outside of the maximum lookback. "
- "For maximum efficiency, choose smaller values of maximum lookback, "
- "trading off performance and potential blank AS OF values for sparse keys"
- )
- # prefix all columns except for series IDs
- left_prefixed_tsdf = self.__addPrefixToAllColumns(left_prefix)
- right_prefixed_tsdf = right.__addPrefixToAllColumns(right_prefix)
-
-
- # Union both dataframes, and create a combined TS column
- combined_ts_col = "combined_ts"
- combined_tsdf = left_prefixed_tsdf.__combineTSDF(right_prefixed_tsdf, combined_ts_col)
- print(f"combined tsdf: {combined_tsdf}")
-
- # set up time partitions
- tsPartitionDF = combined_tsdf.__getTimePartitions(tsPartitionVal,
- fraction=fraction)
- print(f"tsPartitionDF: {tsPartitionDF}")
-
- # resolve correct right-hand rows
- right_cols = list(set(right_prefixed_tsdf.columns).difference(combined_tsdf.series_ids))
- asofDF = tsPartitionDF.__getLastRightRow(
- left_prefixed_tsdf.ts_col,
- right_cols,
- tsPartitionVal,
- skipNulls,
- suppress_null_warning,
- )
- print(f"asofDF: {asofDF}")
-
- # Get rid of overlapped data and the extra columns generated from timePartitions
- df = ( asofDF.df.filter(Fn.col("is_original") == 1)
- .drop("ts_partition", "is_original"))
-
- return TSDF(df, ts_col=asofDF.ts_col, series_ids=asofDF.series_ids)
-
- def __standardAsOfJoin(self,
- right: TSDF,
- left_prefix: str,
- right_prefix: str,
- skipNulls: bool = True,
- suppress_null_warning: bool = False) -> TSDF:
- # prefix all columns except for series IDs
- left_prefixed_tsdf = self.__addPrefixToAllColumns(left_prefix)
- right_prefixed_tsdf = right.__addPrefixToAllColumns(right_prefix)
-
- # Union both dataframes, and create a combined TS column
- combined_ts_col = "combined_ts"
- combined_tsdf = left_prefixed_tsdf.__combineTSDF(right_prefixed_tsdf, combined_ts_col)
-
- # resolve correct right-hand rows
- right_cols = list(set(right_prefixed_tsdf.columns).difference(combined_tsdf.series_ids))
- asofDF = combined_tsdf.__getLastRightRow(
- left_prefixed_tsdf.ts_col,
- right_cols,
- None,
- skipNulls,
- suppress_null_warning,
- )
-
- return asofDF
-
- def asOfJoin(
- self,
- right: TSDF,
- left_prefix: str = None,
- right_prefix: str = "right",
- tsPartitionVal = None,
- fraction: float = 0.5,
- skipNulls: bool = True,
- sql_join_opt: bool = False,
- suppress_null_warning: bool = False,
- ):
- """
- Performs an as-of join between two time-series. If a tsPartitionVal is specified, it will do this partitioned by
- time brackets, which can help alleviate skew.
-
- NOTE: partition cols have to be the same for both Dataframes. We are collecting stats when the WARNING level is
- enabled also.
-
- Parameters
- :param right - right-hand data frame containing columns to merge in
- :param left_prefix - optional prefix for base data frame
- :param right_prefix - optional prefix for right-hand data frame
- :param tsPartitionVal - value to partition each series into time brackets
- :param fraction - overlap fraction
- :param skipNulls - whether to skip nulls when joining in values
- :param sql_join_opt - if set to True, will use standard Spark SQL join if it is estimated to be efficient
- :param suppress_null_warning - when tsPartitionVal is specified, will collect min of each column and raise warnings about null values, set to True to avoid
- """
-
- # Check whether partition columns have the same name in both dataframes
- self.__hasSameSeriesIDs(right)
-
- # validate timestamp datatypes match
- self.__validateTsColMatch(right)
-
- # execute the broadcast-join variation
- # choose 30MB as the cutoff for the broadcast
- bytes_threshold = 30 * 1024 * 1024
- spark = SparkSession.builder.getOrCreate()
- left_bytes = self.__getBytesFromPlan(self.df, spark)
- right_bytes = self.__getBytesFromPlan(right.df, spark)
- if sql_join_opt & ((left_bytes < bytes_threshold)
- | (right_bytes < bytes_threshold)):
- spark.conf.set("spark.databricks.optimizer.rangeJoin.binSize", "60")
- return self.__broadcastAsOfJoin(right)
-
- # perform as-of join.
- if tsPartitionVal is None:
- return self.__standardAsOfJoin(right,
- left_prefix,
- right_prefix,
- skipNulls,
- suppress_null_warning)
- else:
- return self.__skewAsOfJoin(right,
- left_prefix,
- right_prefix,
- tsPartitionVal,
- skipNulls=skipNulls,
- suppress_null_warning=suppress_null_warning)
-
#
# Slicing & Selection
#
@@ -752,7 +425,9 @@ def select(self, *cols):
if set(self.structural_cols).issubset(set(cols)):
return self.__withTransformedDF(self.df.select(*cols))
else:
- raise TSDFStructureChangeError("select that does not include all structural columns")
+ raise Exception(
+ "In TSDF's select statement original ts_col, partitionCols and seq_col_stub(optional) must be present"
+ )
def __slice(self, op: str, target_ts):
"""
@@ -766,7 +441,7 @@ def __slice(self, op: str, target_ts):
"""
# quote our timestamp if its a string
target_expr = f"'{target_ts}'" if isinstance(target_ts, str) else target_ts
- slice_expr = Fn.expr(f"{self.ts_col} {op} {target_expr}")
+ slice_expr = f.expr(f"{self.ts_col} {op} {target_expr}")
sliced_df = self.df.where(slice_expr)
return self.__withTransformedDF(sliced_df)
@@ -846,8 +521,8 @@ def __top_rows_per_series(self, win: WindowSpec, n: int):
"""
row_num_col = "__row_num"
prev_records_df = (
- self.df.withColumn(row_num_col, Fn.row_number().over(win))
- .where(Fn.col(row_num_col) <= Fn.lit(n))
+ self.df.withColumn(row_num_col, f.row_number().over(win))
+ .where(f.col(row_num_col) <= f.lit(n))
.drop(row_num_col)
)
return self.__withTransformedDF(prev_records_df)
@@ -951,29 +626,29 @@ def describe(self):
# extract the double version of the timestamp column to summarize
double_ts_col = self.ts_col + "_dbl"
- this_df = self.df.withColumn(double_ts_col, Fn.col(self.ts_col).cast("double"))
+ this_df = self.df.withColumn(double_ts_col, f.col(self.ts_col).cast("double"))
# summary missing value percentages
missing_vals = this_df.select(
[
(
100
- * Fn.count(Fn.when(Fn.col(c[0]).isNull(), c[0]))
- / Fn.count(Fn.lit(1))
+ * f.count(f.when(f.col(c[0]).isNull(), c[0]))
+ / f.count(f.lit(1))
).alias(c[0])
for c in this_df.dtypes
if c[1] != "timestamp"
]
- ).select(Fn.lit("missing_vals_pct").alias("summary"), "*")
+ ).select(f.lit("missing_vals_pct").alias("summary"), "*")
# describe stats
desc_stats = this_df.describe().union(missing_vals)
unique_ts = this_df.select(*self.series_ids).distinct().count()
- max_ts = this_df.select(Fn.max(Fn.col(self.ts_col)).alias("max_ts")).collect()[0][
+ max_ts = this_df.select(f.max(f.col(self.ts_col)).alias("max_ts")).collect()[0][
0
]
- min_ts = this_df.select(Fn.min(Fn.col(self.ts_col)).alias("max_ts")).collect()[0][
+ min_ts = this_df.select(f.min(f.col(self.ts_col)).alias("max_ts")).collect()[0][
0
]
gran = this_df.selectExpr(
@@ -989,22 +664,22 @@ def describe(self):
non_summary_cols = [c for c in desc_stats.columns if c != "summary"]
desc_stats = desc_stats.select(
- Fn.col("summary"),
- Fn.lit(" ").alias("unique_ts_count"),
- Fn.lit(" ").alias("min_ts"),
- Fn.lit(" ").alias("max_ts"),
- Fn.lit(" ").alias("granularity"),
+ f.col("summary"),
+ f.lit(" ").alias("unique_ts_count"),
+ f.lit(" ").alias("min_ts"),
+ f.lit(" ").alias("max_ts"),
+ f.lit(" ").alias("granularity"),
*non_summary_cols,
)
# add in single record with global summary attributes and the previously computed missing value and Spark data frame describe stats
global_smry_rec = desc_stats.limit(1).select(
- Fn.lit("global").alias("summary"),
- Fn.lit(unique_ts).alias("unique_ts_count"),
- Fn.lit(min_ts).alias("min_ts"),
- Fn.lit(max_ts).alias("max_ts"),
- Fn.lit(gran).alias("granularity"),
- *[Fn.lit(" ").alias(c) for c in non_summary_cols],
+ f.lit("global").alias("summary"),
+ f.lit(unique_ts).alias("unique_ts_count"),
+ f.lit(min_ts).alias("min_ts"),
+ f.lit(max_ts).alias("max_ts"),
+ f.lit(gran).alias("granularity"),
+ *[f.lit(" ").alias(c) for c in non_summary_cols],
)
full_smry = global_smry_rec.union(desc_stats)
@@ -1019,83 +694,236 @@ def describe(self):
return full_smry
pass
- #
- # Window helper functions
- #
+ def __getBytesFromPlan(self, df, spark):
+ """
+ Internal helper function to obtain how many bytes in memory the Spark data frame is likely to take up. This is an upper bound and is obtained from the plan details in Spark
- def __baseWindow(self, reverse=False):
- # The index will determine the appropriate sort order
- w = Window().orderBy(self.ts_index.orderByExpr(reverse))
+ Parameters
+ :param df - input Spark data frame - the AS OF join has 2 data frames; this will be called for each
+ :param spark - Spark session which is used to query the view obtained from the Spark data frame
+ """
- # and partitioned by any series IDs
- if self.series_ids:
- w = w.partitionBy([Fn.col(sid) for sid in self.series_ids])
- return w
+ df.createOrReplaceTempView("view")
+ plan = spark.sql("explain cost select * from view").collect()[0][0]
- def __rowsBetweenWindow(self, rows_from, rows_to, reverse=False):
- return self.__baseWindow(reverse=reverse).rowsBetween(rows_from, rows_to)
+ import re
- def __rangeBetweenWindow(self, range_from, range_to, reverse=False):
- return ( self.__baseWindow(reverse=reverse)
- .orderBy(self.ts_index.rangeOrderByExpr(reverse=reverse))
- .rangeBetween(range_from, range_to ) )
+ result = (
+ re.search(r"sizeInBytes=.*(['\)])", plan, re.MULTILINE)
+ .group(0)
+ .replace(")", "")
+ )
+ size = result.split("=")[1].split(" ")[0]
+ units = result.split("=")[1].split(" ")[1]
- #
- # Core Transformations
- #
+ # perform to MB for threshold check
+ if units == "GiB":
+ bytes = float(size) * 1024 * 1024 * 1024
+ elif units == "MiB":
+ bytes = float(size) * 1024 * 1024
+ elif units == "KiB":
+ bytes = float(size) * 1024
+ else:
+ bytes = float(size)
- def withColumn(self, colName: str, col: Column) -> "TSDF":
- """
- Returns a new :class:`TSDF` by adding a column or replacing the existing column that has the same name.
+ return bytes
- :param colName: the name of the new column (or existing column to be replaced)
- :param col: a :class:`Column` expression for the new column definition
+ def asofJoin(
+ self,
+ right_tsdf,
+ left_prefix=None,
+ right_prefix="right",
+ tsPartitionVal=None,
+ fraction=0.5,
+ skipNulls=True,
+ sql_join_opt=False,
+ suppress_null_warning=False,
+ ):
"""
- if colName in self.structural_cols:
- raise TSDFStructureChangeError(f"withColumn on the structural column {colName}.")
- new_df = self.df.withColumn(colName, col)
- return self.__withTransformedDF(new_df)
+ Performs an as-of join between two time-series. If a tsPartitionVal is specified, it will do this partitioned by
+ time brackets, which can help alleviate skew.
- def withColumnRenamed(self, existing: str, new: str) -> "TSDF":
- """
- Returns a new :class:`TSDF` with the given column renamed.
+ NOTE: partition cols have to be the same for both Dataframes. We are collecting stats when the WARNING level is
+ enabled also.
- :param existing: name of the existing column to renmame
- :param new: new name for the column
+ Parameters
+ :param right_tsdf - right-hand data frame containing columns to merge in
+ :param left_prefix - optional prefix for base data frame
+ :param right_prefix - optional prefix for right-hand data frame
+ :param tsPartitionVal - value to break up each partition into time brackets
+ :param fraction - overlap fraction
+ :param skipNulls - whether to skip nulls when joining in values
+ :param sql_join_opt - if set to True, will use standard Spark SQL join if it is estimated to be efficient
+ :param suppress_null_warning - when tsPartitionVal is specified, will collect min of each column and raise warnings about null values, set to True to avoid
"""
- # create new TSIndex
- new_ts_index = deepcopy(self.ts_index)
- if existing == self.ts_index.name:
- new_ts_index = new_ts_index.renamed(new)
+ # first block of logic checks whether a standard range join will suffice
+ left_df = self.df
+ right_df = right_tsdf.df
- # and for series ids
- new_series_ids = self.series_ids
- if existing in self.series_ids:
- # replace column name in series
- new_series_ids = self.series_ids
- new_series_ids[new_series_ids.index(existing)] = new
+ spark = SparkSession.builder.getOrCreate()
+ left_bytes = self.__getBytesFromPlan(left_df, spark)
+ right_bytes = self.__getBytesFromPlan(right_df, spark)
- # rename the column in the underlying DF
- new_df = self.df.withColumnRenamed(existing,new)
+ # choose 30MB as the cutoff for the broadcast
+ bytes_threshold = 30 * 1024 * 1024
+ if sql_join_opt & (
+ (left_bytes < bytes_threshold) | (right_bytes < bytes_threshold)
+ ):
+ spark.conf.set("spark.databricks.optimizer.rangeJoin.binSize", 60)
+ partition_cols = right_tsdf.series_ids
+ left_cols = list(set(left_df.columns).difference(set(self.series_ids)))
+ right_cols = list(
+ set(right_df.columns).difference(set(right_tsdf.series_ids))
+ )
- # return new TSDF
- new_schema = TSSchema(new_ts_index, new_series_ids)
- return TSDF(new_df, ts_schema=new_schema)
+ left_prefix = (
+ ""
+ if ((left_prefix is None) | (left_prefix == ""))
+ else left_prefix + "_"
+ )
+ right_prefix = (
+ ""
+ if ((right_prefix is None) | (right_prefix == ""))
+ else right_prefix + "_"
+ )
- def union(self, other: TSDF) -> TSDF:
- # union of the underlying DataFrames
- union_df = self.df.union(other.df)
- return self.__withTransformedDF(union_df)
+ w = Window.partitionBy(*partition_cols).orderBy(
+ right_prefix + right_tsdf.ts_col
+ )
- def unionByName(self, other: TSDF, allowMissingColumns: bool = False) -> TSDF:
- # union of the underlying DataFrames
- union_df = self.df.unionByName(other.df, allowMissingColumns=allowMissingColumns)
- return self.__withTransformedDF(union_df)
+ new_left_ts_col = left_prefix + self.ts_col
+ new_left_cols = [
+ f.col(c).alias(left_prefix + c) for c in left_cols
+ ] + partition_cols
+ new_right_cols = [
+ f.col(c).alias(right_prefix + c) for c in right_cols
+ ] + partition_cols
+ quotes_df_w_lag = right_df.select(*new_right_cols).withColumn(
+ "lead_" + right_tsdf.ts_col,
+ f.lead(right_prefix + right_tsdf.ts_col).over(w),
+ )
+ left_df = left_df.select(*new_left_cols)
+ res = (
+ left_df.join(quotes_df_w_lag, partition_cols)
+ .where(
+ left_df[new_left_ts_col].between(
+ f.col(right_prefix + right_tsdf.ts_col),
+ f.coalesce(
+ f.col("lead_" + right_tsdf.ts_col),
+ f.lit("2099-01-01").cast("timestamp"),
+ ),
+ )
+ )
+ .drop("lead_" + right_tsdf.ts_col)
+ )
+ return TSDF(res, series_ids=self.series_ids, ts_col=new_left_ts_col)
- #
- # utility functions
- #
+ # end of block checking to see if standard Spark SQL join will work
+
+ if tsPartitionVal is not None:
+ logger.warning(
+ "You are using the skew version of the AS OF join. This may result in null values if there are any values outside of the maximum lookback. For maximum efficiency, choose smaller values of maximum lookback, trading off performance and potential blank AS OF values for sparse keys"
+ )
+
+ # Check whether partition columns have same name in both dataframes
+ self.__checkPartitionCols(right_tsdf)
+
+ # prefix non-partition columns, to avoid duplicated columns.
+ left_df = self.df
+ right_df = right_tsdf.df
+
+ # validate timestamp datatypes match
+ self.__validateTsColMatch(right_tsdf)
+
+ orig_left_col_diff = list(set(left_df.columns).difference(set(self.series_ids)))
+ orig_right_col_diff = list(
+ set(right_df.columns).difference(set(self.series_ids))
+ )
+
+ left_tsdf = (
+ (self.__addPrefixToColumns([self.ts_col] + orig_left_col_diff, left_prefix))
+ if left_prefix is not None
+ else self
+ )
+ right_tsdf = right_tsdf.__addPrefixToColumns(
+ [right_tsdf.ts_col] + orig_right_col_diff, right_prefix
+ )
+
+ left_nonpartition_cols = list(
+ set(left_tsdf.df.columns).difference(set(self.series_ids))
+ )
+ right_nonpartition_cols = list(
+ set(right_tsdf.df.columns).difference(set(self.series_ids))
+ )
+
+ # For both dataframes get all non-partition columns (including ts_col)
+ left_columns = [left_tsdf.ts_col] + left_nonpartition_cols
+ right_columns = [right_tsdf.ts_col] + right_nonpartition_cols
+
+ # Union both dataframes, and create a combined TS column
+ combined_ts_col = "combined_ts"
+ combined_df = left_tsdf.__addColumnsFromOtherDF(right_columns).__combineTSDF(
+ right_tsdf.__addColumnsFromOtherDF(left_columns), combined_ts_col
+ )
+ combined_df.df = combined_df.df.withColumn(
+ "rec_ind", f.when(f.col(left_tsdf.ts_col).isNotNull(), 1).otherwise(-1)
+ )
+
+ # perform asof join.
+ if tsPartitionVal is None:
+ seq_col = None
+ if isinstance(combined_df.ts_index, SubSequenceTSIndex):
+ seq_col = combined_df.ts_index.sub_seq_col
+ asofDF = combined_df.__getLastRightRow(
+ left_tsdf.ts_col,
+ right_columns,
+ seq_col,
+ tsPartitionVal,
+ skipNulls,
+ suppress_null_warning,
+ )
+ else:
+ tsPartitionDF = combined_df.__getTimePartitions(
+ tsPartitionVal, fraction=fraction
+ )
+ seq_col = None
+ if isinstance(tsPartitionDF.ts_index, SubSequenceTSIndex):
+ seq_col = tsPartitionDF.ts_index.sub_seq_col
+ asofDF = tsPartitionDF.__getLastRightRow(
+ left_tsdf.ts_col,
+ right_columns,
+ seq_col,
+ tsPartitionVal,
+ skipNulls,
+ suppress_null_warning,
+ )
+
+ # Get rid of overlapped data and the extra columns generated from timePartitions
+ df = asofDF.df.filter(f.col("is_original") == 1).drop(
+ "ts_partition", "is_original"
+ )
+
+ asofDF = TSDF(df, ts_col=asofDF.ts_col, series_ids=combined_df.series_ids)
+
+ return asofDF
+
+ def __baseWindow(self, reverse=False):
+ # The index will determine the appropriate sort order
+ w = Window().orderBy(self.ts_index.orderByExpr(reverse))
+
+ # and partitioned by any series IDs
+ if self.series_ids:
+ w = w.partitionBy([f.col(sid) for sid in self.series_ids])
+ return w
+
+ def __rowsBetweenWindow(self, rows_from, rows_to, reverse=False):
+ return self.__baseWindow(reverse=reverse).rowsBetween(rows_from, rows_to)
+
+ def __rangeBetweenWindow(self, range_from, range_to, reverse=False):
+ return ( self.__baseWindow(reverse=reverse)
+ .orderBy(self.ts_index.rangeOrderByExpr(reverse=reverse))
+ .rangeBetween(range_from, range_to ) )
def vwap(self, frequency="m", volume_col="volume", price_col="price"):
# set pre_vwap as self or enrich with the frequency
@@ -1103,33 +931,33 @@ def vwap(self, frequency="m", volume_col="volume", price_col="price"):
if frequency == "m":
pre_vwap = self.df.withColumn(
"time_group",
- Fn.concat(
- Fn.lpad(Fn.hour(Fn.col(self.ts_col)), 2, "0"),
- Fn.lit(":"),
- Fn.lpad(Fn.minute(Fn.col(self.ts_col)), 2, "0"),
+ f.concat(
+ f.lpad(f.hour(f.col(self.ts_col)), 2, "0"),
+ f.lit(":"),
+ f.lpad(f.minute(f.col(self.ts_col)), 2, "0"),
),
)
elif frequency == "H":
pre_vwap = self.df.withColumn(
- "time_group", Fn.concat(Fn.lpad(Fn.hour(Fn.col(self.ts_col)), 2, "0"))
+ "time_group", f.concat(f.lpad(f.hour(f.col(self.ts_col)), 2, "0"))
)
elif frequency == "D":
pre_vwap = self.df.withColumn(
- "time_group", Fn.concat(Fn.lpad(Fn.day(Fn.col(self.ts_col)), 2, "0"))
+ "time_group", f.concat(f.lpad(f.day(f.col(self.ts_col)), 2, "0"))
)
group_cols = ["time_group"]
if self.series_ids:
group_cols.extend(self.series_ids)
vwapped = (
- pre_vwap.withColumn("dllr_value", Fn.col(price_col) * Fn.col(volume_col))
+ pre_vwap.withColumn("dllr_value", f.col(price_col) * f.col(volume_col))
.groupby(group_cols)
.agg(
sum("dllr_value").alias("dllr_value"),
sum(volume_col).alias(volume_col),
max(price_col).alias("_".join(["max", price_col])),
)
- .withColumn("vwap", Fn.col("dllr_value") / Fn.col(volume_col))
+ .withColumn("vwap", f.col("dllr_value") / f.col(volume_col))
)
return self.__withTransformedDF(vwapped)
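VWAP here is simply sum(price × volume) / sum(volume) per time group and series; a tiny numeric sketch with made-up trades:

```python
# Illustrative sketch of the VWAP aggregation with hypothetical values.
trades = [  # (price, volume) within one time_group / symbol
    (351.32, 100),
    (351.40, 300),
    (351.10, 50),
]
dllr_value = sum(p * v for p, v in trades)
volume = sum(v for _, v in trades)
vwap = dllr_value / volume
print(round(vwap, 4))  # 351.3489
```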
@@ -1143,18 +971,18 @@ def EMA(self, colName, window=30, exp_factor=0.2):
"""
emaColName = "_".join(["EMA", colName])
- df = self.df.withColumn(emaColName, Fn.lit(0)).orderBy(self.ts_col)
+ df = self.df.withColumn(emaColName, f.lit(0)).orderBy(self.ts_col)
w = self.__baseWindow()
# Generate all the lag columns:
for i in range(window):
lagColName = "_".join(["lag", colName, str(i)])
weight = exp_factor * (1 - exp_factor) ** i
- df = df.withColumn(lagColName, weight * Fn.lag(Fn.col(colName), i).over(w))
+ df = df.withColumn(lagColName, weight * f.lag(f.col(colName), i).over(w))
df = df.withColumn(
emaColName,
- Fn.col(emaColName)
- + Fn.when(Fn.col(lagColName).isNull(), Fn.lit(0)).otherwise(
- Fn.col(lagColName)
+ f.col(emaColName)
+ + f.when(f.col(lagColName).isNull(), f.lit(0)).otherwise(
+ f.col(lagColName)
),
).drop(lagColName)
# Nulls are currently removed
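The lag-column construction implements a truncated exponential moving average with weights exp_factor * (1 - exp_factor)**i; a quick numeric sketch (illustrative only):

```python
# Illustrative sketch: the truncated EMA weights used by TSDF.EMA.
exp_factor = 0.2
window = 5
weights = [exp_factor * (1 - exp_factor) ** i for i in range(window)]
print([round(w, 4) for w in weights])  # [0.2, 0.16, 0.128, 0.1024, 0.0819]

# the EMA at a row is then sum(weights[i] * value[row - i]), with missing lags treated as 0
values = [10.0, 11.0, 12.0, 13.0, 14.0]          # oldest .. newest, hypothetical
ema_latest = sum(w * v for w, v in zip(weights, reversed(values)))
print(round(ema_latest, 4))  # 8.3616
```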
@@ -1181,17 +1009,17 @@ def withLookbackFeatures(
"""
# first, join all featureCols into a single array column
tempArrayColName = "__TempArrayCol"
- feat_array_tsdf = self.df.withColumn(tempArrayColName, Fn.array(featureCols))
+ feat_array_tsdf = self.df.withColumn(tempArrayColName, f.array(featureCols))
# construct a lookback array
lookback_win = self.__rowsBetweenWindow(-lookbackWindowSize, -1)
lookback_tsdf = feat_array_tsdf.withColumn(
- featureColName, Fn.collect_list(Fn.col(tempArrayColName)).over(lookback_win)
+ featureColName, f.collect_list(f.col(tempArrayColName)).over(lookback_win)
).drop(tempArrayColName)
# make sure only windows of exact size are allowed
if exactSize:
- return lookback_tsdf.where(Fn.size(featureColName) == lookbackWindowSize)
+ return lookback_tsdf.where(f.size(featureColName) == lookbackWindowSize)
return self.__withTransformedDF(lookback_tsdf)
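A stripped-down sketch of the lookback construction above — pack the feature columns into an array, then `collect_list` over a trailing rows window (column names here are hypothetical):

```python
# Illustrative sketch: build a rolling lookback of the previous N feature vectors.
from pyspark.sql import SparkSession, Window
import pyspark.sql.functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame(
    [("u1", 1, 0.1, 1.0), ("u1", 2, 0.2, 2.0), ("u1", 3, 0.3, 3.0)],
    "user string, ts long, x double, y double",
)

lookback_size = 2
w = Window.partitionBy("user").orderBy("ts").rowsBetween(-lookback_size, -1)

with_lookback = (
    df.withColumn("__features", F.array("x", "y"))
    .withColumn("lookback", F.collect_list("__features").over(w))
    .drop("__features")
    .where(F.size("lookback") == lookback_size)   # keep only full windows
)
with_lookback.show(truncate=False)  # the ts=3 row carries [[0.1, 1.0], [0.2, 2.0]]
```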
@@ -1224,16 +1052,16 @@ def withRangeStats(
selectedCols = self.df.columns
derivedCols = []
for metric in colsToSummarize:
- selectedCols.append(Fn.mean(metric).over(w).alias("mean_" + metric))
- selectedCols.append(Fn.count(metric).over(w).alias("count_" + metric))
- selectedCols.append(Fn.min(metric).over(w).alias("min_" + metric))
- selectedCols.append(Fn.max(metric).over(w).alias("max_" + metric))
- selectedCols.append(Fn.sum(metric).over(w).alias("sum_" + metric))
- selectedCols.append(Fn.stddev(metric).over(w).alias("stddev_" + metric))
+ selectedCols.append(f.mean(metric).over(w).alias("mean_" + metric))
+ selectedCols.append(f.count(metric).over(w).alias("count_" + metric))
+ selectedCols.append(f.min(metric).over(w).alias("min_" + metric))
+ selectedCols.append(f.max(metric).over(w).alias("max_" + metric))
+ selectedCols.append(f.sum(metric).over(w).alias("sum_" + metric))
+ selectedCols.append(f.stddev(metric).over(w).alias("stddev_" + metric))
derivedCols.append(
(
- (Fn.col(metric) - Fn.col("mean_" + metric))
- / Fn.col("stddev_" + metric)
+ (f.col(metric) - f.col("mean_" + metric))
+ / f.col("stddev_" + metric)
).alias("zscore_" + metric)
)
selected_df = self.df.select(*selectedCols)
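The derived z-score column is the usual (value − rolling mean) / rolling standard deviation over the range window; a tiny numeric check (illustrative only, using the sample standard deviation as Spark's `stddev` does):

```python
# Illustrative sketch: the z-score derived alongside the range stats.
import statistics

window_vals = [10.0, 12.0, 11.0, 13.0]   # hypothetical values inside one range window
current = window_vals[-1]
z = (current - statistics.mean(window_vals)) / statistics.stdev(window_vals)
print(round(z, 4))  # 1.1619
```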
@@ -1274,8 +1102,8 @@ def withGroupedStats(self, metricCols=[], freq=None):
# build window
parsed_freq = rs.checkAllowableFreq(freq)
- agg_window = Fn.window(
- Fn.col(self.ts_col),
+ agg_window = f.window(
+ f.col(self.ts_col),
"{} {}".format(parsed_freq[0], rs.freq_dict[parsed_freq[1]]),
)
@@ -1284,19 +1112,19 @@ def withGroupedStats(self, metricCols=[], freq=None):
for metric in metricCols:
selectedCols.extend(
[
- Fn.mean(Fn.col(metric)).alias("mean_" + metric),
- Fn.count(Fn.col(metric)).alias("count_" + metric),
- Fn.min(Fn.col(metric)).alias("min_" + metric),
- Fn.max(Fn.col(metric)).alias("max_" + metric),
- Fn.sum(Fn.col(metric)).alias("sum_" + metric),
- Fn.stddev(Fn.col(metric)).alias("stddev_" + metric),
+ f.mean(f.col(metric)).alias("mean_" + metric),
+ f.count(f.col(metric)).alias("count_" + metric),
+ f.min(f.col(metric)).alias("min_" + metric),
+ f.max(f.col(metric)).alias("max_" + metric),
+ f.sum(f.col(metric)).alias("sum_" + metric),
+ f.stddev(f.col(metric)).alias("stddev_" + metric),
]
)
selected_df = self.df.groupBy(self.series_ids + [agg_window]).agg(*selectedCols)
summary_df = (
selected_df.select(*selected_df.columns)
- .withColumn(self.ts_col, Fn.col("window").start)
+ .withColumn(self.ts_col, f.col("window").start)
.drop("window")
)
@@ -1462,11 +1290,11 @@ def tempo_fourier_util(pdf):
data = self.df
if self.series_ids == []:
- data = data.withColumn("dummy_group", Fn.lit("dummy_val"))
+ data = data.withColumn("dummy_group", f.lit("dummy_val"))
data = (
- data.select(Fn.col("dummy_group"), self.ts_col, Fn.col(valueCol))
- .withColumn("tdval", Fn.col(valueCol))
- .withColumn("tpoints", Fn.col(self.ts_col))
+ data.select(f.col("dummy_group"), self.ts_col, f.col(valueCol))
+ .withColumn("tdval", f.col(valueCol))
+ .withColumn("tpoints", f.col(self.ts_col))
)
return_schema = ",".join(
[f"{i[0]} {i[1]}" for i in data.dtypes]
@@ -1479,9 +1307,9 @@ def tempo_fourier_util(pdf):
else:
group_cols = self.series_ids
data = (
- data.select(*group_cols, self.ts_col, Fn.col(valueCol))
- .withColumn("tdval", Fn.col(valueCol))
- .withColumn("tpoints", Fn.col(self.ts_col))
+ data.select(*group_cols, self.ts_col, f.col(valueCol))
+ .withColumn("tdval", f.col(valueCol))
+ .withColumn("tpoints", f.col(self.ts_col))
)
return_schema = ",".join(
[f"{i[0]} {i[1]}" for i in data.dtypes]
@@ -1520,7 +1348,7 @@ def extractStateIntervals(
# https://spark.apache.org/docs/latest/sql-ref-null-semantics.html#comparison-operators-
def null_safe_equals(col1: Column, col2: Column) -> Column:
return (
- Fn.when(col1.isNull() & col2.isNull(), True)
+ f.when(col1.isNull() & col2.isNull(), True)
.when(col1.isNull() | col2.isNull(), False)
.otherwise(operator.eq(col1, col2))
)
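The `null_safe_equals` helper above reproduces null-safe equality by hand; Spark's `Column.eqNullSafe` exposes the same truth table directly (shown only as a point of comparison, not what this code uses):

```python
# Illustrative sketch: Column.eqNullSafe matches the null_safe_equals truth table.
from pyspark.sql import SparkSession
import pyspark.sql.functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(1.0, 1.0), (1.0, None), (None, None)], "a double, b double")
df.select(
    "a", "b",
    F.col("a").eqNullSafe(F.col("b")).alias("null_safe_eq"),  # true, false, true
).show()
```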
@@ -1572,7 +1400,7 @@ def state_comparison_fn(a, b):
# Get previous timestamp to identify start time of the interval
data = data.withColumn(
"previous_ts",
- Fn.lag(Fn.col(self.ts_col), offset=1).over(w),
+ f.lag(f.col(self.ts_col), offset=1).over(w),
)
# Determine state intervals using user-provided the state comparison function
@@ -1582,31 +1410,31 @@ def state_comparison_fn(a, b):
temp_metric_compare_col = f"__{mc}_compare"
data = data.withColumn(
temp_metric_compare_col,
- state_comparison_fn(Fn.col(mc), Fn.lag(Fn.col(mc), 1).over(w)),
+ state_comparison_fn(f.col(mc), f.lag(f.col(mc), 1).over(w)),
)
temp_metric_compare_cols.append(temp_metric_compare_col)
# Remove first record which will have no state change
# and produces `null` for all state comparisons
- data = data.filter(Fn.col("previous_ts").isNotNull())
+ data = data.filter(f.col("previous_ts").isNotNull())
# Each state comparison should return True if state remained constant
data = data.withColumn(
- "state_change", Fn.array_contains(Fn.array(*temp_metric_compare_cols), False)
+ "state_change", f.array_contains(f.array(*temp_metric_compare_cols), False)
)
# Count the distinct state changes to get the unique intervals
data = data.withColumn(
"state_incrementer",
- Fn.sum(Fn.col("state_change").cast("int")).over(w),
- ).filter(~Fn.col("state_change"))
+ f.sum(f.col("state_change").cast("int")).over(w),
+ ).filter(~f.col("state_change"))
# Find the start and end timestamp of the interval
result = (
data.groupBy(*self.series_ids, "state_incrementer")
.agg(
- Fn.min("previous_ts").alias("start_ts"),
- Fn.max(self.ts_col).alias("end_ts"),
+ f.min("previous_ts").alias("start_ts"),
+ f.max(self.ts_col).alias("end_ts"),
)
.drop("state_incrementer")
)
diff --git a/python/tempo/tsschema.py b/python/tempo/tsschema.py
index a93e93f6..4fa1f091 100644
--- a/python/tempo/tsschema.py
+++ b/python/tempo/tsschema.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
-from typing import Any, Union, Collection, List, Iterable
+from typing import Union, Collection, List
from pyspark.sql import Column
import pyspark.sql.functions as Fn
@@ -15,25 +15,6 @@ class TSIndex(ABC):
Abstract base class for all Timeseries Index types
"""
- def __eq__(self, o: object) -> bool:
- # must be a SimpleTSIndex
- if not isinstance(o, TSIndex):
- return False
- return self.indexAttributes == o.indexAttributes
-
- def __repr__(self) -> str:
- return self.__str__()
-
- def __str__(self) -> str:
- return f"""{self.__class__.__name__}({self.indexAttributes})"""
-
- @property
- @abstractmethod
- def indexAttributes(self) -> dict[str,Any]:
- """
- :return: key attributes of this index
- """
-
@property
@abstractmethod
def name(self) -> str:
@@ -48,16 +29,6 @@ def ts_col(self) -> str:
:return: the name of the primary timeseries column (may or may not be the same as the name)
"""
- @abstractmethod
- def renamed(self, new_name: str) -> "TSIndex":
- """
- Renames the index
-
- :param new_name: new name of the index
-
- :return: a copy of this :class:`TSIndex` object with the new name
- """
-
def _reverseOrNot(self, expr: Union[Column, List[Column]], reverse: bool) -> Union[Column, List[Column]]:
if not reverse:
return expr # just return the expression as-is if we're not reversing
@@ -101,11 +72,6 @@ def __init__(self, ts_col: StructField) -> None:
self.__name = ts_col.name
self.dataType = ts_col.dataType
- @property
- def indexAttributes(self) -> dict[str, Any]:
- return { 'name': self.name,
- 'dataType': self.dataType }
-
@property
def name(self):
return self.__name
@@ -114,10 +80,6 @@ def name(self):
def ts_col(self) -> str:
return self.name
- def renamed(self, new_name: str) -> "TSIndex":
- self.__name = new_name
- return self
-
def orderByExpr(self, reverse: bool = False) -> Column:
expr = Fn.col(self.name)
return self._reverseOrNot(expr, reverse)
@@ -190,20 +152,14 @@ class CompositeTSIndex(TSIndex, ABC):
def __init__(self, composite_ts_idx: StructField, primary_ts_col: str) -> None:
if not isinstance(composite_ts_idx.dataType, StructType):
raise TypeError(f"CompoundTSIndex must be of type StructType, but given compound_ts_idx {composite_ts_idx.name} has type {composite_ts_idx.dataType}")
- self.__name: str = composite_ts_idx.name
+ self.ts_idx: str = composite_ts_idx.name
self.struct: StructType = composite_ts_idx.dataType
# construct a simple TS index object for the primary column
self.primary_ts_idx: SimpleTSIndex = SimpleTSIndex.fromTSCol(self.struct[primary_ts_col])
- @property
- def indexAttributes(self) -> dict[str, Any]:
- return { 'name': self.name,
- 'struct': self.struct,
- 'primary_ts_col': self.primary_ts_idx }
-
@property
def name(self) -> str:
- return self.__name
+ return self.ts_idx
@property
def ts_col(self) -> str:
@@ -213,10 +169,6 @@ def ts_col(self) -> str:
def primary_ts_col(self) -> str:
return self.component(self.primary_ts_idx.name)
- def renamed(self, new_name: str) -> "TSIndex":
- self.__name = new_name
- return self
-
def component(self, component_name):
"""
Returns the full path to a component column that is within the composite index
@@ -244,12 +196,6 @@ def __init__(self, composite_ts_idx: StructField, primary_ts_col: str, sub_seq_c
# construct a simple index for the sub-sequence column
self.sub_sequence_idx = NumericIndex(self.struct[sub_seq_col])
- @property
- def indexAttributes(self) -> dict[str, Any]:
- attrs = super().indexAttributes
- attrs['sub_sequence_idx'] = self.sub_sequence_idx
- return attrs
-
@property
def sub_seq_col(self) -> str:
return self.component(self.sub_sequence_idx.name)
@@ -273,12 +219,6 @@ def __init__(self, composite_ts_idx: StructField, src_str_col: str, parsed_col:
raise TypeError(f"Source string column must be of StringType, but given column {src_str_field.name} is of type {src_str_field.dataType}")
self.__src_str_col = src_str_col
- @property
- def indexAttributes(self) -> dict[str, Any]:
- attrs = super().indexAttributes
- attrs['src_str_col'] = self.src_str_col
- return attrs
-
@property
def src_str_col(self):
return self.component(self.__src_str_col)
@@ -345,27 +285,7 @@ def __init__(
if series_ids:
self.series_ids = list(series_ids)
else:
- self.series_ids = []
-
- def __eq__(self, o: object) -> bool:
- # must be of TSSchema type
- if not isinstance(o, TSSchema):
- return False
- # must have same TSIndex
- if self.ts_idx != o.ts_idx:
- return False
- # must have the same series IDs
- if self.series_ids != o.series_ids:
- return False
- return True
-
- def __repr__(self) -> str:
- return self.__str__()
-
- def __str__(self) -> str:
- return f"""TSSchema({id(self)})
- TSIndex: {self.ts_idx}
- Series IDs: {self.series_ids}"""
+ self.series_ids = None
@classmethod
def fromDFSchema(
@@ -376,20 +296,22 @@ def fromDFSchema(
return cls(ts_idx, series_ids)
@property
- def structural_columns(self) -> list[str]:
+ def structural_columns(self) -> set[str]:
"""
Structural columns are those that define the structure of the :class:`TSDF`. This includes the timeseries column,
a timeseries index (if different), any subsequence column (if present), and the series ID columns.
:return: a set of column names corresponding the structural columns of a :class:`TSDF`
"""
- return list({self.ts_idx.name}.union(self.series_ids))
+ struct_cols = {self.ts_idx.name}.union(self.series_ids)
+ struct_cols.discard(None)
+ return struct_cols
def validate(self, df_schema: StructType) -> None:
pass
- def find_observational_columns(self, df_schema: StructType) -> list[str]:
- return list(set(df_schema.fieldNames()) - set(self.structural_columns))
+ def find_observational_columns(self, df_schema: StructType) -> set[str]:
+ return set(df_schema.fieldNames()) - self.structural_columns
def find_metric_columns(self, df_schema: StructType) -> list[str]:
return [
diff --git a/python/tests/as_of_join_tests.py b/python/tests/as_of_join_tests.py
index 17ef8b32..1ca2b77d 100644
--- a/python/tests/as_of_join_tests.py
+++ b/python/tests/as_of_join_tests.py
@@ -14,10 +14,10 @@ def test_asof_join(self):
noRightPrefixdfExpected = self.get_data_as_sdf("expected_no_right_prefix")
# perform the join
- joined_df = tsdf_left.asOfJoin(
+ joined_df = tsdf_left.asofJoin(
tsdf_right, left_prefix="left", right_prefix="right"
).df
- non_prefix_joined_df = tsdf_left.asOfJoin(
+ non_prefix_joined_df = tsdf_left.asofJoin(
tsdf_right, left_prefix="left", right_prefix=""
).df
@@ -25,7 +25,7 @@ def test_asof_join(self):
self.assertDataFrameEquality(joined_df, dfExpected)
self.assertDataFrameEquality(non_prefix_joined_df, noRightPrefixdfExpected)
- spark_sql_joined_df = tsdf_left.asOfJoin(
+ spark_sql_joined_df = tsdf_left.asofJoin(
tsdf_right, left_prefix="left", right_prefix="right"
).df
self.assertDataFrameEquality(spark_sql_joined_df, dfExpected)
@@ -42,7 +42,7 @@ def test_asof_join_skip_nulls_disabled(self):
)
# perform the join with skip nulls enabled (default)
- joined_df = tsdf_left.asOfJoin(
+ joined_df = tsdf_left.asofJoin(
tsdf_right, left_prefix="left", right_prefix="right"
).df
@@ -50,7 +50,7 @@ def test_asof_join_skip_nulls_disabled(self):
self.assertDataFrameEquality(joined_df, dfExpectedSkipNulls)
# perform the join with skip nulls disabled
- joined_df = tsdf_left.asOfJoin(
+ joined_df = tsdf_left.asofJoin(
tsdf_right, left_prefix="left", right_prefix="right", skipNulls=False
).df
@@ -66,10 +66,10 @@ def test_sequence_number_sort(self):
dfExpected = self.get_data_as_sdf("expected")
# perform the join
- joined_df = tsdf_left.asOfJoin(tsdf_right, right_prefix="right")
+ joined_df = tsdf_left.asofJoin(tsdf_right, right_prefix="right").df
# joined dataframe should equal the expected dataframe
- self.assertDataFrameEquality(joined_df.df, dfExpected)
+ self.assertDataFrameEquality(joined_df, dfExpected)
def test_partitioned_asof_join(self):
"""AS-OF Join with a time-partition"""
@@ -79,7 +79,7 @@ def test_partitioned_asof_join(self):
tsdf_right = self.get_data_as_tsdf("right")
dfExpected = self.get_data_as_sdf("expected")
- joined_df = tsdf_left.asOfJoin(
+ joined_df = tsdf_left.asofJoin(
tsdf_right,
left_prefix="left",
right_prefix="right",
@@ -98,7 +98,7 @@ def test_asof_join_nanos(self):
dfExpected = self.get_data_as_sdf("expected")
# perform join
- joined_df = tsdf_left.asOfJoin(
+ joined_df = tsdf_left.asofJoin(
tsdf_right, left_prefix="left", right_prefix="right"
).df
diff --git a/python/tests/base.py b/python/tests/base.py
index 84b24201..ed8548ec 100644
--- a/python/tests/base.py
+++ b/python/tests/base.py
@@ -144,14 +144,7 @@ def buildTestDF(self, schema, data, ts_cols=["event_ts"]):
# convert timstamp fields to timestamp type
for tsc in ts_cols:
- # check if the column is nested in a struct or not
- if '.' in tsc:
- # we're changing a field nested in a struct
- (struct, field) = tsc.split('.')
- df = df.withColumn(struct, F.col(struct).withField(field, F.to_timestamp(tsc)))
- else:
- # standard column
- df = df.withColumn(tsc, F.to_timestamp(F.col(tsc)))
+ df = df.withColumn(tsc, F.to_timestamp(F.col(tsc)))
return df
#
diff --git a/python/tests/unit_test_data/as_of_join_tests.json b/python/tests/unit_test_data/as_of_join_tests.json
index ff1c0595..f68be584 100644
--- a/python/tests/unit_test_data/as_of_join_tests.json
+++ b/python/tests/unit_test_data/as_of_join_tests.json
@@ -121,16 +121,16 @@
]
},
"expected": {
- "schema": "symbol string, combined_ts struct, event_ts: string, trade_pr float, trade_id int, right_ts_idx struct, right_bid_pr float, right_ask_pr float",
+ "schema": "symbol string, event_ts string, trade_pr float, trade_id int, right_event_ts string, right_bid_pr float, right_ask_pr float, right_seq_nb long",
"ts_col": "event_ts",
"series_ids": ["symbol"],
- "other_ts_cols": ["combined_ts.event_ts", "right_ts_idx.event_ts"],
+ "other_ts_cols": ["right_event_ts"],
"data": [
- ["S1", ["2020-08-01 00:00:10", 1], "2020-08-01 00:00:10", 349.21, 1, ["2020-08-01 00:00:10", 1], 19.11, 20.12],
- ["S1", ["2020-08-01 00:00:10", 1], "2020-08-01 00:00:10", 350.21, 5, ["2020-08-01 00:00:10", 1], 19.11, 20.12],
- ["S1", ["2020-08-01 00:01:12", 3], "2020-08-01 00:01:12", 351.32, 2, ["2020-08-01 00:01:05", 3], 348.10, 1000.13],
- ["S1", ["2020-09-01 00:02:10", 4], "2020-09-01 00:02:10", 361.1, 3, ["2020-09-01 00:02:01", 4], 358.93, 365.12],
- ["S1", ["2020-09-01 00:19:12", 5], "2020-09-01 00:19:12", 362.1, 4, ["2020-09-01 00:15:01", 5], 359.21, 365.31]
+ ["S1", "2020-08-01 00:00:10", 349.21, 1, "2020-08-01 00:00:10", 19.11, 20.12, 1],
+ ["S1", "2020-08-01 00:00:10", 350.21, 5, "2020-08-01 00:00:10", 19.11, 20.12, 1],
+ ["S1", "2020-08-01 00:01:12", 351.32, 2, "2020-08-01 00:01:05", 348.10, 1000.13, 3],
+ ["S1", "2020-09-01 00:02:10", 361.1, 3, "2020-09-01 00:02:01", 358.93, 365.12, 4],
+ ["S1", "2020-09-01 00:19:12", 362.1, 4, "2020-09-01 00:15:01", 359.21, 365.31, 5]
]
}
},
From 99c5e9c36263f717687c24043b51f126646a9cfe Mon Sep 17 00:00:00 2001
From: Tristan Nixon
Date: Mon, 24 Oct 2022 17:06:45 -0700
Subject: [PATCH 08/11] merging changes from integration branch
---
python/tempo/tsdf.py | 476 +++++++++++++++++++++---------------
python/tempo/tsschema.py | 127 ++++++++--
python/tempo/utils.py | 10 +-
python/tests/base.py | 16 +-
python/tests/tsdf_tests.py | 382 +++++++++++++++--------------
python/tests/utils_tests.py | 31 +--
6 files changed, 605 insertions(+), 437 deletions(-)
diff --git a/python/tempo/tsdf.py b/python/tempo/tsdf.py
index 97b348b9..2afa7120 100644
--- a/python/tempo/tsdf.py
+++ b/python/tempo/tsdf.py
@@ -2,12 +2,12 @@
import logging
import operator
-from functools import reduce
-from typing import List, Union, Callable, Collection, Set
from copy import deepcopy
+from functools import reduce, cached_property
+from typing import List, Union, Callable, Collection
import numpy as np
-import pyspark.sql.functions as f
+import pyspark.sql.functions as Fn
from IPython.core.display import HTML
from IPython.display import display as ipydisplay
from pyspark.sql import SparkSession
@@ -24,11 +24,34 @@
ENV_CAN_RENDER_HTML,
IS_DATABRICKS,
calculate_time_horizon,
- get_display_df,
+ get_display_df
)
logger = logging.getLogger(__name__)
+class TSDFStructureChangeError(Exception):
+ """
+ Error raised when a user attempts an operation that would alter the structure of a TSDF in a destructive manner.
+ """
+ __MSG_TEMPLATE: str = """
+ The attempted operation ({op}) is not allowed because it would result in altering the structure of the TSDF.
+ If you really want to make this change, perform the operation on the underlying DataFrame, then re-create a new TSDF.
+ {d}"""
+
+ def __init__(self, operation: str, details: str = None) -> None:
+ super().__init__(self.__MSG_TEMPLATE.format(op=operation, d=details))
+
+
+class IncompatibleTSError(Exception):
+ """
+ Error raised when an operation is attempted between two incompatible TSDFs.
+ """
+ __MSG_TEMPLATE: str = """
+ The attempted operation ({op}) cannot be performed because the given TSDFs have incompatible structure.
+ {d}"""
+
+ def __init__(self, operation: str, details: str = None) -> None:
+ super().__init__(self.__MSG_TEMPLATE.format(op=operation, d=details))
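A minimal sketch of how these exceptions are meant to surface (the DataFrame, column names and values below are purely illustrative, and assume both exception classes remain importable from tempo.tsdf, where this patch defines them):

    import pyspark.sql.functions as Fn
    from pyspark.sql import SparkSession
    from tempo.tsdf import TSDF, TSDFStructureChangeError

    spark = SparkSession.builder.getOrCreate()
    trades = spark.createDataFrame(
        [("S1", "2020-08-01 00:00:10", 349.21)],
        "symbol string, event_ts string, trade_pr double",
    ).withColumn("event_ts", Fn.to_timestamp("event_ts"))

    tsdf = TSDF(trades, ts_col="event_ts", series_ids=["symbol"])
    try:
        tsdf.withColumn("event_ts", Fn.lit(None))  # overwriting a structural column
    except TSDFStructureChangeError as err:
        print(err)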
class TSDF:
"""
@@ -41,7 +64,7 @@ def __init__(
ts_schema: TSSchema = None,
ts_col: str = None,
series_ids: Collection[str] = None,
- validate_schema=True,
+ validate_schema=True
) -> None:
self.df = df
# construct schema if we don't already have one
@@ -53,6 +76,16 @@ def __init__(
if validate_schema:
self.ts_schema.validate(df.schema)
+ def __repr__(self) -> str:
+ return self.__str__()
+
+ def __str__(self) -> str:
+ return f"""TSDF({id(self)}):
+ TS Index: {self.ts_index}
+ Series IDs: {self.series_ids}
+ Observational Cols: {self.observational_cols}
+ DataFrame: {self.df.schema}"""
+
def __withTransformedDF(self, new_df: DataFrame) -> "TSDF":
"""
This helper function will create a new :class:`TSDF` using the current schema, but a new / transformed :class:`DataFrame`
@@ -63,6 +96,19 @@ def __withTransformedDF(self, new_df: DataFrame) -> "TSDF":
"""
return TSDF(new_df, ts_schema=deepcopy(self.ts_schema), validate_schema=False)
+ def __withStandardizedColOrder(self) -> TSDF:
+ """
+ Standardizes the column ordering as such:
+ * series_ids,
+ * ts_index,
+ * observation columns
+
+ :return: a :class:`TSDF` with the columns reordered into "standard order" (as described above)
+ """
+ std_ordered_cols = list(self.series_ids) + [self.ts_index.name] + list(self.observational_cols)
+
+ return self.__withTransformedDF(self.df.select(std_ordered_cols))
+
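For reference, the "standard order" described above is simply series IDs, then the timeseries index, then the remaining observational columns; a toy illustration with hypothetical column names:

    # hypothetical inputs for a TSDF keyed by "symbol" with ts index "event_ts"
    series_ids = ["symbol"]
    ts_index_name = "event_ts"
    observational_cols = ["trade_pr", "trade_id"]

    std_ordered_cols = series_ids + [ts_index_name] + observational_cols
    print(std_ordered_cols)  # ['symbol', 'event_ts', 'trade_pr', 'trade_id']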
@classmethod
def __makeStructFromCols(cls, df: DataFrame, struct_col_name: str, cols_to_move: List[str]) -> DataFrame:
"""
@@ -74,8 +120,9 @@ def __makeStructFromCols(cls, df: DataFrame, struct_col_name: str, cols_to_move:
:return: the transformed :class:`DataFrame`
"""
- return df.withColumn(struct_col_name, f.struct(cols_to_move)).drop(*cols_to_move)
+ return df.withColumn(struct_col_name, Fn.struct(cols_to_move)).drop(*cols_to_move)
+ # default column name for constructed timeseries index struct columns
__DEFAULT_TS_IDX_COL = "ts_idx"
@classmethod
@@ -106,102 +153,82 @@ def ts_index(self) -> "TSIndex":
def ts_col(self) -> str:
return self.ts_index.ts_col
+ @property
+ def columns(self) -> List[str]:
+ return self.df.columns
+
@property
def series_ids(self) -> List[str]:
return self.ts_schema.series_ids
@property
- def structural_cols(self) -> Set[str]:
+ def structural_cols(self) -> List[str]:
return self.ts_schema.structural_columns
- @property
+ @cached_property
def observational_cols(self) -> List[str]:
- return list(self.ts_schema.find_observational_columns(self.df.schema))
+ return self.ts_schema.find_observational_columns(self.df.schema)
- @property
+ @cached_property
def metric_cols(self) -> List[str]:
return self.ts_schema.find_metric_columns(self.df.schema)
- # def __init__(self, df, ts_col="event_ts", partition_cols=None, sequence_col=None):
- # """
- # Constructor
- # :param df:
- # :param ts_col:
- # :param partitionCols:
- # :sequence_col every tsdf allows for a tie-breaker secondary sort key
- # """
- # self.ts_col = self.__validated_column(df, ts_col)
- # self.partitionCols = (
- # []
- # if partition_cols is None
- # else self.__validated_columns(df, partition_cols.copy())
- # )
- #
- # self.df = df
- # self.sequence_col = "" if sequence_col is None else sequence_col
- #
- # # Add customized check for string type for the timestamp. If we see a string, we will proactively created a double version of the string timestamp for sorting purposes and rename to ts_col
- # if df.schema[ts_col].dataType == "StringType":
- # sample_ts = df.limit(1).collect()[0][0]
- # self.__validate_ts_string(sample_ts)
- # self.__add_double_ts().withColumnRenamed("double_ts", self.ts_col)
-
#
# Helper functions
#
- def __add_double_ts(self):
- """Add a double (epoch) version of the string timestamp out to nanos"""
- self.df = (
- self.df.withColumn(
- "nanos",
- (
- f.when(
- f.col(self.ts_col).contains("."),
- f.concat(f.lit("0."), f.split(f.col(self.ts_col), "\.")[1]),
- ).otherwise(0)
- ).cast("double"),
- )
- .withColumn("long_ts", f.col(self.ts_col).cast("timestamp").cast("long"))
- .withColumn("double_ts", f.col("long_ts") + f.col("nanos"))
- .drop("nanos")
- .drop("long_ts")
- )
-
- def __validate_ts_string(self, ts_text):
- """Validate the format for the string using Regex matching for ts_string"""
- import re
-
- ts_pattern = "^\d{4}-\d{2}-\d{2}T| \d{2}:\d{2}:\d{2}\.\d*$"
- if re.match(ts_pattern, ts_text) is None:
- raise ValueError(
- "Incorrect data format, should be YYYY-MM-DD HH:MM:SS[.nnnnnnnn]"
- )
+ # def __add_double_ts(self):
+ # """Add a double (epoch) version of the string timestamp out to nanos"""
+ # self.df = (
+ # self.df.withColumn(
+ # "nanos",
+ # (
+ # Fn.when(
+ # Fn.col(self.ts_col).contains("."),
+ # Fn.concat(Fn.lit("0."), Fn.split(Fn.col(self.ts_col), "\.")[1]),
+ # ).otherwise(0)
+ # ).cast("double"),
+ # )
+ # .withColumn("long_ts", Fn.col(self.ts_col).cast("timestamp").cast("long"))
+ # .withColumn("double_ts", Fn.col("long_ts") + Fn.col("nanos"))
+ # .drop("nanos")
+ # .drop("long_ts")
+ # )
- def __validated_column(self, df, colname):
- if type(colname) != str:
- raise TypeError(
- f"Column names must be of type str; found {type(colname)} instead!"
- )
- if colname.lower() not in [col.lower() for col in df.columns]:
- raise ValueError(f"Column {colname} not found in Dataframe")
- return colname
-
- def __validated_columns(self, df, colnames):
- # if provided a string, treat it as a single column
- if type(colnames) == str:
- colnames = [colnames]
- # otherwise we really should have a list or None
- if colnames is None:
- colnames = []
- elif type(colnames) != list:
- raise TypeError(
- f"Columns must be of type list, str, or None; found {type(colnames)} instead!"
- )
- # validate each column
- for col in colnames:
- self.__validated_column(df, col)
- return colnames
+ # def __validate_ts_string(self, ts_text):
+ # """Validate the format for the string using Regex matching for ts_string"""
+ # import re
+ #
+ # ts_pattern = "^\d{4}-\d{2}-\d{2}T| \d{2}:\d{2}:\d{2}\.\d*$"
+ # if re.match(ts_pattern, ts_text) is None:
+ # raise ValueError(
+ # "Incorrect data format, should be YYYY-MM-DD HH:MM:SS[.nnnnnnnn]"
+ # )
+
+ # def __validated_column(self, df, colname):
+ # if type(colname) != str:
+ # raise TypeError(
+ # f"Column names must be of type str; found {type(colname)} instead!"
+ # )
+ # if colname.lower() not in [col.lower() for col in df.columns]:
+ # raise ValueError(f"Column {colname} not found in Dataframe")
+ # return colname
+
+ # def __validated_columns(self, df, colnames):
+ # # if provided a string, treat it as a single column
+ # if type(colnames) == str:
+ # colnames = [colnames]
+ # # otherwise we really should have a list or None
+ # if colnames is None:
+ # colnames = []
+ # elif type(colnames) != list:
+ # raise TypeError(
+ # f"Columns must be of type list, str, or None; found {type(colnames)} instead!"
+ # )
+ # # validate each column
+ # for col in colnames:
+ # self.__validated_column(df, col)
+ # return colnames
def __checkPartitionCols(self, tsdf_right):
for left_col, right_col in zip(self.series_ids, tsdf_right.series_ids):
@@ -246,7 +273,7 @@ def __addColumnsFromOtherDF(self, other_cols):
Add columns from some other DF as lit(None), as pre-step before union.
"""
new_df = reduce(
- lambda df, idx: df.withColumn(other_cols[idx], f.lit(None)),
+ lambda df, idx: df.withColumn(other_cols[idx], Fn.lit(None)),
range(len(other_cols)),
self.df,
)
@@ -255,7 +282,7 @@ def __addColumnsFromOtherDF(self, other_cols):
def __combineTSDF(self, ts_df_right, combined_ts_col):
combined_df = self.df.unionByName(ts_df_right.df).withColumn(
- combined_ts_col, f.coalesce(self.ts_col, ts_df_right.ts_col)
+ combined_ts_col, Fn.coalesce(self.ts_col, ts_df_right.ts_col)
)
return TSDF(combined_df, ts_col=combined_ts_col, series_ids=self.series_ids)
@@ -275,7 +302,7 @@ def __getLastRightRow(
since it is no longer used in subsequent methods.
"""
ptntl_sort_keys = [self.ts_col, "rec_ind", sequence_col]
- sort_keys = [f.col(col_name) for col_name in ptntl_sort_keys if col_name]
+ sort_keys = [Fn.col(col_name) for col_name in ptntl_sort_keys if col_name]
window_spec = (
Window.partitionBy(self.series_ids)
@@ -291,9 +318,9 @@ def __getLastRightRow(
df = reduce(
lambda df, idx: df.withColumn(
right_cols[idx],
- f.last(
- f.when(
- f.col("rec_ind") == -1, f.struct(right_cols[idx])
+ Fn.last(
+ Fn.when(
+ Fn.col("rec_ind") == -1, Fn.struct(right_cols[idx])
).otherwise(None),
True, # ignore nulls because it indicates rows from the left side
).over(window_spec),
@@ -303,7 +330,7 @@ def __getLastRightRow(
)
df = reduce(
lambda df, idx: df.withColumn(
- right_cols[idx], f.col(right_cols[idx])[right_cols[idx]]
+ right_cols[idx], Fn.col(right_cols[idx])[right_cols[idx]]
),
range(len(right_cols)),
df,
@@ -313,7 +340,7 @@ def __getLastRightRow(
df = reduce(
lambda df, idx: df.withColumn(
right_cols[idx],
- f.last(right_cols[idx], ignoreNulls).over(window_spec),
+ Fn.last(right_cols[idx], ignoreNulls).over(window_spec),
),
range(len(right_cols)),
self.df,
@@ -322,16 +349,16 @@ def __getLastRightRow(
df = reduce(
lambda df, idx: df.withColumn(
right_cols[idx],
- f.last(right_cols[idx], ignoreNulls).over(window_spec),
+ Fn.last(right_cols[idx], ignoreNulls).over(window_spec),
).withColumn(
"non_null_ct" + right_cols[idx],
- f.count(right_cols[idx]).over(window_spec),
+ Fn.count(right_cols[idx]).over(window_spec),
),
range(len(right_cols)),
self.df,
)
- df = (df.filter(f.col(left_ts_col).isNotNull()).drop(self.ts_col)).drop(
+ df = (df.filter(Fn.col(left_ts_col).isNotNull()).drop(self.ts_col)).drop(
"rec_ind"
)
@@ -372,26 +399,26 @@ def __getTimePartitions(self, tsPartitionVal, fraction=0.1):
"""
partition_df = (
self.df.withColumn(
- "ts_col_double", f.col(self.ts_col).cast("double")
+ "ts_col_double", Fn.col(self.ts_col).cast("double")
) # double is preferred over unix_timestamp
.withColumn(
"ts_partition",
- f.lit(tsPartitionVal)
- * (f.col("ts_col_double") / f.lit(tsPartitionVal)).cast("integer"),
+ Fn.lit(tsPartitionVal)
+ * (Fn.col("ts_col_double") / Fn.lit(tsPartitionVal)).cast("integer"),
)
.withColumn(
"partition_remainder",
- (f.col("ts_col_double") - f.col("ts_partition"))
- / f.lit(tsPartitionVal),
+ (Fn.col("ts_col_double") - Fn.col("ts_partition"))
+ / Fn.lit(tsPartitionVal),
)
- .withColumn("is_original", f.lit(1))
+ .withColumn("is_original", Fn.lit(1))
).cache() # cache it because it's used twice.
# add [1 - fraction] of previous time partition to the next partition.
remainder_df = (
- partition_df.filter(f.col("partition_remainder") >= f.lit(1 - fraction))
- .withColumn("ts_partition", f.col("ts_partition") + f.lit(tsPartitionVal))
- .withColumn("is_original", f.lit(0))
+ partition_df.filter(Fn.col("partition_remainder") >= Fn.lit(1 - fraction))
+ .withColumn("ts_partition", Fn.col("ts_partition") + Fn.lit(tsPartitionVal))
+ .withColumn("is_original", Fn.lit(0))
)
df = partition_df.union(remainder_df).drop(
@@ -425,9 +452,7 @@ def select(self, *cols):
if set(self.structural_cols).issubset(set(cols)):
return self.__withTransformedDF(self.df.select(*cols))
else:
- raise Exception(
- "In TSDF's select statement original ts_col, partitionCols and seq_col_stub(optional) must be present"
- )
+ raise TSDFStructureChangeError("select that does not include all structural columns")
def __slice(self, op: str, target_ts):
"""
@@ -441,7 +466,7 @@ def __slice(self, op: str, target_ts):
"""
# quote our timestamp if its a string
target_expr = f"'{target_ts}'" if isinstance(target_ts, str) else target_ts
- slice_expr = f.expr(f"{self.ts_col} {op} {target_expr}")
+ slice_expr = Fn.expr(f"{self.ts_col} {op} {target_expr}")
sliced_df = self.df.where(slice_expr)
return self.__withTransformedDF(sliced_df)
@@ -521,8 +546,8 @@ def __top_rows_per_series(self, win: WindowSpec, n: int):
"""
row_num_col = "__row_num"
prev_records_df = (
- self.df.withColumn(row_num_col, f.row_number().over(win))
- .where(f.col(row_num_col) <= f.lit(n))
+ self.df.withColumn(row_num_col, Fn.row_number().over(win))
+ .where(Fn.col(row_num_col) <= Fn.lit(n))
.drop(row_num_col)
)
return self.__withTransformedDF(prev_records_df)
@@ -613,7 +638,7 @@ def show(self, n=20, k=5, truncate=True, vertical=False):
if not (IS_DATABRICKS) and ENV_CAN_RENDER_HTML:
# In Jupyter notebooks, for wide dataframes the below line will enable rendering the output in a scrollable format.
ipydisplay(HTML(""))
- get_display_df(self, k).show(n, truncate, vertical)
+        get_display_df(self, k=k).show(n, truncate, vertical)
def describe(self):
"""
@@ -626,29 +651,29 @@ def describe(self):
# extract the double version of the timestamp column to summarize
double_ts_col = self.ts_col + "_dbl"
- this_df = self.df.withColumn(double_ts_col, f.col(self.ts_col).cast("double"))
+ this_df = self.df.withColumn(double_ts_col, Fn.col(self.ts_col).cast("double"))
# summary missing value percentages
missing_vals = this_df.select(
[
(
100
- * f.count(f.when(f.col(c[0]).isNull(), c[0]))
- / f.count(f.lit(1))
+ * Fn.count(Fn.when(Fn.col(c[0]).isNull(), c[0]))
+ / Fn.count(Fn.lit(1))
).alias(c[0])
for c in this_df.dtypes
if c[1] != "timestamp"
]
- ).select(f.lit("missing_vals_pct").alias("summary"), "*")
+ ).select(Fn.lit("missing_vals_pct").alias("summary"), "*")
# describe stats
desc_stats = this_df.describe().union(missing_vals)
unique_ts = this_df.select(*self.series_ids).distinct().count()
- max_ts = this_df.select(f.max(f.col(self.ts_col)).alias("max_ts")).collect()[0][
+ max_ts = this_df.select(Fn.max(Fn.col(self.ts_col)).alias("max_ts")).collect()[0][
0
]
- min_ts = this_df.select(f.min(f.col(self.ts_col)).alias("max_ts")).collect()[0][
+ min_ts = this_df.select(Fn.min(Fn.col(self.ts_col)).alias("max_ts")).collect()[0][
0
]
gran = this_df.selectExpr(
@@ -664,22 +689,22 @@ def describe(self):
non_summary_cols = [c for c in desc_stats.columns if c != "summary"]
desc_stats = desc_stats.select(
- f.col("summary"),
- f.lit(" ").alias("unique_ts_count"),
- f.lit(" ").alias("min_ts"),
- f.lit(" ").alias("max_ts"),
- f.lit(" ").alias("granularity"),
+ Fn.col("summary"),
+ Fn.lit(" ").alias("unique_ts_count"),
+ Fn.lit(" ").alias("min_ts"),
+ Fn.lit(" ").alias("max_ts"),
+ Fn.lit(" ").alias("granularity"),
*non_summary_cols,
)
# add in single record with global summary attributes and the previously computed missing value and Spark data frame describe stats
global_smry_rec = desc_stats.limit(1).select(
- f.lit("global").alias("summary"),
- f.lit(unique_ts).alias("unique_ts_count"),
- f.lit(min_ts).alias("min_ts"),
- f.lit(max_ts).alias("max_ts"),
- f.lit(gran).alias("granularity"),
- *[f.lit(" ").alias(c) for c in non_summary_cols],
+ Fn.lit("global").alias("summary"),
+ Fn.lit(unique_ts).alias("unique_ts_count"),
+ Fn.lit(min_ts).alias("min_ts"),
+ Fn.lit(max_ts).alias("max_ts"),
+ Fn.lit(gran).alias("granularity"),
+ *[Fn.lit(" ").alias(c) for c in non_summary_cols],
)
full_smry = global_smry_rec.union(desc_stats)
@@ -770,7 +795,7 @@ def asofJoin(
if sql_join_opt & (
(left_bytes < bytes_threshold) | (right_bytes < bytes_threshold)
):
- spark.conf.set("spark.databricks.optimizer.rangeJoin.binSize", 60)
+ spark.conf.set("spark.databricks.optimizer.rangeJoin.binSize", "60")
partition_cols = right_tsdf.series_ids
left_cols = list(set(left_df.columns).difference(set(self.series_ids)))
right_cols = list(
@@ -793,25 +818,21 @@ def asofJoin(
)
new_left_ts_col = left_prefix + self.ts_col
- new_left_cols = [
- f.col(c).alias(left_prefix + c) for c in left_cols
- ] + partition_cols
- new_right_cols = [
- f.col(c).alias(right_prefix + c) for c in right_cols
- ] + partition_cols
+ new_left_cols = [Fn.col(c).alias(left_prefix + c) for c in left_cols] + partition_cols
+ new_right_cols = [Fn.col(c).alias(right_prefix + c) for c in right_cols] + partition_cols
quotes_df_w_lag = right_df.select(*new_right_cols).withColumn(
"lead_" + right_tsdf.ts_col,
- f.lead(right_prefix + right_tsdf.ts_col).over(w),
+ Fn.lead(right_prefix + right_tsdf.ts_col).over(w),
)
left_df = left_df.select(*new_left_cols)
res = (
left_df.join(quotes_df_w_lag, partition_cols)
.where(
left_df[new_left_ts_col].between(
- f.col(right_prefix + right_tsdf.ts_col),
- f.coalesce(
- f.col("lead_" + right_tsdf.ts_col),
- f.lit("2099-01-01").cast("timestamp"),
+ Fn.col(right_prefix + right_tsdf.ts_col),
+ Fn.coalesce(
+ Fn.col("lead_" + right_tsdf.ts_col),
+ Fn.lit("2099-01-01").cast("timestamp"),
),
)
)
@@ -867,7 +888,7 @@ def asofJoin(
right_tsdf.__addColumnsFromOtherDF(left_columns), combined_ts_col
)
combined_df.df = combined_df.df.withColumn(
- "rec_ind", f.when(f.col(left_tsdf.ts_col).isNotNull(), 1).otherwise(-1)
+ "rec_ind", Fn.when(Fn.col(left_tsdf.ts_col).isNotNull(), 1).otherwise(-1)
)
# perform asof join.
@@ -900,7 +921,7 @@ def asofJoin(
)
# Get rid of overlapped data and the extra columns generated from timePartitions
- df = asofDF.df.filter(f.col("is_original") == 1).drop(
+ df = asofDF.df.filter(Fn.col("is_original") == 1).drop(
"ts_partition", "is_original"
)
@@ -914,7 +935,7 @@ def __baseWindow(self, reverse=False):
# and partitioned by any series IDs
if self.series_ids:
- w = w.partitionBy([f.col(sid) for sid in self.series_ids])
+ w = w.partitionBy([Fn.col(sid) for sid in self.series_ids])
return w
def __rowsBetweenWindow(self, rows_from, rows_to, reverse=False):
@@ -925,39 +946,106 @@ def __rangeBetweenWindow(self, range_from, range_to, reverse=False):
.orderBy(self.ts_index.rangeOrderByExpr(reverse=reverse))
.rangeBetween(range_from, range_to ) )
+ #
+ # Core Transformations
+ #
+
+ def withNaturalOrdering(self, reverse: bool = False) -> "TSDF":
+        order_expr = [Fn.col(c) for c in self.series_ids]
+ ts_idx_expr = self.ts_index.orderByExpr(reverse)
+ if isinstance(ts_idx_expr, list):
+ order_expr.extend(ts_idx_expr)
+ else:
+ order_expr.append(ts_idx_expr)
+
+ return self.__withTransformedDF(self.df.orderBy(order_expr))
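A brief usage sketch (the TSDF `tsdf` and its columns are hypothetical): rows come back grouped by the series IDs and sorted by the timeseries index, ascending by default.

    ordered = tsdf.withNaturalOrdering()
    ordered.df.show()  # grouped by series IDs, ascending by the ts index

    newest_first = tsdf.withNaturalOrdering(reverse=True)
    newest_first.df.show()  # same grouping, ts index descending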
+
+ def withColumn(self, colName: str, col: Column) -> "TSDF":
+ """
+ Returns a new :class:`TSDF` by adding a column or replacing the existing column that has the same name.
+
+ :param colName: the name of the new column (or existing column to be replaced)
+ :param col: a :class:`Column` expression for the new column definition
+ """
+ if colName in self.structural_cols:
+ raise TSDFStructureChangeError(f"withColumn on the structural column {colName}.")
+ new_df = self.df.withColumn(colName, col)
+ return self.__withTransformedDF(new_df)
+
+ def withColumnRenamed(self, existing: str, new: str) -> "TSDF":
+ """
+ Returns a new :class:`TSDF` with the given column renamed.
+
+        :param existing: name of the existing column to rename
+ :param new: new name for the column
+ """
+
+ # create new TSIndex
+ new_ts_index = deepcopy(self.ts_index)
+ if existing == self.ts_index.name:
+ new_ts_index = new_ts_index.renamed(new)
+
+ # and for series ids
+        new_series_ids = list(self.series_ids)
+        if existing in self.series_ids:
+            # replace the column name in the copied series IDs
+            new_series_ids[new_series_ids.index(existing)] = new
+
+ # rename the column in the underlying DF
+        new_df = self.df.withColumnRenamed(existing, new)
+
+ # return new TSDF
+ new_schema = TSSchema(new_ts_index, new_series_ids)
+ return TSDF(new_df, ts_schema=new_schema)
+
+ def union(self, other: TSDF) -> TSDF:
+ # union of the underlying DataFrames
+ union_df = self.df.union(other.df)
+ return self.__withTransformedDF(union_df)
+
+ def unionByName(self, other: TSDF, allowMissingColumns: bool = False) -> TSDF:
+ # union of the underlying DataFrames
+ union_df = self.df.unionByName(other.df, allowMissingColumns=allowMissingColumns)
+ return self.__withTransformedDF(union_df)
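A short usage sketch for the transformations above (the `trades` TSDF and its column names are hypothetical):

    import pyspark.sql.functions as Fn

    # adding a derived observational column is allowed
    with_spread = trades.withColumn("spread", Fn.col("ask_pr") - Fn.col("bid_pr"))

    # renaming a series ID (or the ts column) keeps the TSSchema in sync
    renamed = trades.withColumnRenamed("symbol", "ticker")
    assert "ticker" in renamed.series_ids

    # touching a structural column directly is rejected
    # trades.withColumn("symbol", Fn.lit("X"))  # raises TSDFStructureChangeError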
+
+ #
+ # utility functions
+ #
+
def vwap(self, frequency="m", volume_col="volume", price_col="price"):
# set pre_vwap as self or enrich with the frequency
pre_vwap = self.df
if frequency == "m":
pre_vwap = self.df.withColumn(
"time_group",
- f.concat(
- f.lpad(f.hour(f.col(self.ts_col)), 2, "0"),
- f.lit(":"),
- f.lpad(f.minute(f.col(self.ts_col)), 2, "0"),
+ Fn.concat(
+ Fn.lpad(Fn.hour(Fn.col(self.ts_col)), 2, "0"),
+ Fn.lit(":"),
+ Fn.lpad(Fn.minute(Fn.col(self.ts_col)), 2, "0"),
),
)
elif frequency == "H":
pre_vwap = self.df.withColumn(
- "time_group", f.concat(f.lpad(f.hour(f.col(self.ts_col)), 2, "0"))
+ "time_group", Fn.concat(Fn.lpad(Fn.hour(Fn.col(self.ts_col)), 2, "0"))
)
elif frequency == "D":
pre_vwap = self.df.withColumn(
- "time_group", f.concat(f.lpad(f.day(f.col(self.ts_col)), 2, "0"))
+ "time_group", Fn.concat(Fn.lpad(Fn.day(Fn.col(self.ts_col)), 2, "0"))
)
group_cols = ["time_group"]
if self.series_ids:
group_cols.extend(self.series_ids)
vwapped = (
- pre_vwap.withColumn("dllr_value", f.col(price_col) * f.col(volume_col))
+ pre_vwap.withColumn("dllr_value", Fn.col(price_col) * Fn.col(volume_col))
.groupby(group_cols)
.agg(
sum("dllr_value").alias("dllr_value"),
sum(volume_col).alias(volume_col),
max(price_col).alias("_".join(["max", price_col])),
)
- .withColumn("vwap", f.col("dllr_value") / f.col(volume_col))
+ .withColumn("vwap", Fn.col("dllr_value") / Fn.col(volume_col))
)
return self.__withTransformedDF(vwapped)
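The VWAP computed above reduces to sum(price * volume) / sum(volume) per time bucket and series; a quick worked check with hypothetical numbers:

    # two trades falling into the same minute bucket for one symbol
    prices = [10.0, 12.0]
    volumes = [100, 300]

    dllr_value = sum(p * v for p, v in zip(prices, volumes))  # 10*100 + 12*300 = 4600.0
    vwap = dllr_value / sum(volumes)                          # 4600.0 / 400 = 11.5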
@@ -971,18 +1059,18 @@ def EMA(self, colName, window=30, exp_factor=0.2):
"""
emaColName = "_".join(["EMA", colName])
- df = self.df.withColumn(emaColName, f.lit(0)).orderBy(self.ts_col)
+ df = self.df.withColumn(emaColName, Fn.lit(0)).orderBy(self.ts_col)
w = self.__baseWindow()
# Generate all the lag columns:
for i in range(window):
lagColName = "_".join(["lag", colName, str(i)])
weight = exp_factor * (1 - exp_factor) ** i
- df = df.withColumn(lagColName, weight * f.lag(f.col(colName), i).over(w))
+ df = df.withColumn(lagColName, weight * Fn.lag(Fn.col(colName), i).over(w))
df = df.withColumn(
emaColName,
- f.col(emaColName)
- + f.when(f.col(lagColName).isNull(), f.lit(0)).otherwise(
- f.col(lagColName)
+ Fn.col(emaColName)
+ + Fn.when(Fn.col(lagColName).isNull(), Fn.lit(0)).otherwise(
+ Fn.col(lagColName)
),
).drop(lagColName)
# Nulls are currently removed
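The EMA above is a truncated exponential weighting over the last `window` lags, where lag i is weighted by exp_factor * (1 - exp_factor)**i; a quick sketch of the weights:

    exp_factor = 0.2
    window = 5
    weights = [exp_factor * (1 - exp_factor) ** i for i in range(window)]
    print([round(w, 4) for w in weights])  # [0.2, 0.16, 0.128, 0.1024, 0.0819]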
@@ -1009,17 +1097,17 @@ def withLookbackFeatures(
"""
# first, join all featureCols into a single array column
tempArrayColName = "__TempArrayCol"
- feat_array_tsdf = self.df.withColumn(tempArrayColName, f.array(featureCols))
+ feat_array_tsdf = self.df.withColumn(tempArrayColName, Fn.array(featureCols))
# construct a lookback array
lookback_win = self.__rowsBetweenWindow(-lookbackWindowSize, -1)
lookback_tsdf = feat_array_tsdf.withColumn(
- featureColName, f.collect_list(f.col(tempArrayColName)).over(lookback_win)
+ featureColName, Fn.collect_list(Fn.col(tempArrayColName)).over(lookback_win)
).drop(tempArrayColName)
# make sure only windows of exact size are allowed
if exactSize:
- return lookback_tsdf.where(f.size(featureColName) == lookbackWindowSize)
+ return lookback_tsdf.where(Fn.size(featureColName) == lookbackWindowSize)
return self.__withTransformedDF(lookback_tsdf)
@@ -1052,16 +1140,16 @@ def withRangeStats(
selectedCols = self.df.columns
derivedCols = []
for metric in colsToSummarize:
- selectedCols.append(f.mean(metric).over(w).alias("mean_" + metric))
- selectedCols.append(f.count(metric).over(w).alias("count_" + metric))
- selectedCols.append(f.min(metric).over(w).alias("min_" + metric))
- selectedCols.append(f.max(metric).over(w).alias("max_" + metric))
- selectedCols.append(f.sum(metric).over(w).alias("sum_" + metric))
- selectedCols.append(f.stddev(metric).over(w).alias("stddev_" + metric))
+ selectedCols.append(Fn.mean(metric).over(w).alias("mean_" + metric))
+ selectedCols.append(Fn.count(metric).over(w).alias("count_" + metric))
+ selectedCols.append(Fn.min(metric).over(w).alias("min_" + metric))
+ selectedCols.append(Fn.max(metric).over(w).alias("max_" + metric))
+ selectedCols.append(Fn.sum(metric).over(w).alias("sum_" + metric))
+ selectedCols.append(Fn.stddev(metric).over(w).alias("stddev_" + metric))
derivedCols.append(
(
- (f.col(metric) - f.col("mean_" + metric))
- / f.col("stddev_" + metric)
+ (Fn.col(metric) - Fn.col("mean_" + metric))
+ / Fn.col("stddev_" + metric)
).alias("zscore_" + metric)
)
selected_df = self.df.select(*selectedCols)
@@ -1102,8 +1190,8 @@ def withGroupedStats(self, metricCols=[], freq=None):
# build window
parsed_freq = rs.checkAllowableFreq(freq)
- agg_window = f.window(
- f.col(self.ts_col),
+ agg_window = Fn.window(
+ Fn.col(self.ts_col),
"{} {}".format(parsed_freq[0], rs.freq_dict[parsed_freq[1]]),
)
@@ -1112,19 +1200,19 @@ def withGroupedStats(self, metricCols=[], freq=None):
for metric in metricCols:
selectedCols.extend(
[
- f.mean(f.col(metric)).alias("mean_" + metric),
- f.count(f.col(metric)).alias("count_" + metric),
- f.min(f.col(metric)).alias("min_" + metric),
- f.max(f.col(metric)).alias("max_" + metric),
- f.sum(f.col(metric)).alias("sum_" + metric),
- f.stddev(f.col(metric)).alias("stddev_" + metric),
+ Fn.mean(Fn.col(metric)).alias("mean_" + metric),
+ Fn.count(Fn.col(metric)).alias("count_" + metric),
+ Fn.min(Fn.col(metric)).alias("min_" + metric),
+ Fn.max(Fn.col(metric)).alias("max_" + metric),
+ Fn.sum(Fn.col(metric)).alias("sum_" + metric),
+ Fn.stddev(Fn.col(metric)).alias("stddev_" + metric),
]
)
selected_df = self.df.groupBy(self.series_ids + [agg_window]).agg(*selectedCols)
summary_df = (
selected_df.select(*selected_df.columns)
- .withColumn(self.ts_col, f.col("window").start)
+ .withColumn(self.ts_col, Fn.col("window").start)
.drop("window")
)
@@ -1286,15 +1374,15 @@ def tempo_fourier_util(pdf):
pdf["freq"] = xf
return pdf[select_cols + ["freq", "ft_real", "ft_imag"]]
- valueCol = self.__validated_column(self.df, valueCol)
+ # valueCol = self.__validated_column(self.df, valueCol)
data = self.df
if self.series_ids == []:
- data = data.withColumn("dummy_group", f.lit("dummy_val"))
+ data = data.withColumn("dummy_group", Fn.lit("dummy_val"))
data = (
- data.select(f.col("dummy_group"), self.ts_col, f.col(valueCol))
- .withColumn("tdval", f.col(valueCol))
- .withColumn("tpoints", f.col(self.ts_col))
+ data.select(Fn.col("dummy_group"), self.ts_col, Fn.col(valueCol))
+ .withColumn("tdval", Fn.col(valueCol))
+ .withColumn("tpoints", Fn.col(self.ts_col))
)
return_schema = ",".join(
[f"{i[0]} {i[1]}" for i in data.dtypes]
@@ -1307,9 +1395,9 @@ def tempo_fourier_util(pdf):
else:
group_cols = self.series_ids
data = (
- data.select(*group_cols, self.ts_col, f.col(valueCol))
- .withColumn("tdval", f.col(valueCol))
- .withColumn("tpoints", f.col(self.ts_col))
+ data.select(*group_cols, self.ts_col, Fn.col(valueCol))
+ .withColumn("tdval", Fn.col(valueCol))
+ .withColumn("tpoints", Fn.col(self.ts_col))
)
return_schema = ",".join(
[f"{i[0]} {i[1]}" for i in data.dtypes]
@@ -1348,7 +1436,7 @@ def extractStateIntervals(
# https://spark.apache.org/docs/latest/sql-ref-null-semantics.html#comparison-operators-
def null_safe_equals(col1: Column, col2: Column) -> Column:
return (
- f.when(col1.isNull() & col2.isNull(), True)
+ Fn.when(col1.isNull() & col2.isNull(), True)
.when(col1.isNull() | col2.isNull(), False)
.otherwise(operator.eq(col1, col2))
)
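For reference, the null-safe comparison above matches the semantics of Spark SQL's <=> operator (NULL <=> NULL is true, NULL <=> x is false, otherwise plain equality); a compact way to express the same check with the built-in API:

    import pyspark.sql.functions as Fn

    # equivalent null-safe equality using the built-in Column operator
    null_safe = Fn.col("a").eqNullSafe(Fn.col("b"))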
@@ -1400,7 +1488,7 @@ def state_comparison_fn(a, b):
# Get previous timestamp to identify start time of the interval
data = data.withColumn(
"previous_ts",
- f.lag(f.col(self.ts_col), offset=1).over(w),
+ Fn.lag(Fn.col(self.ts_col), offset=1).over(w),
)
# Determine state intervals using user-provided the state comparison function
@@ -1410,31 +1498,31 @@ def state_comparison_fn(a, b):
temp_metric_compare_col = f"__{mc}_compare"
data = data.withColumn(
temp_metric_compare_col,
- state_comparison_fn(f.col(mc), f.lag(f.col(mc), 1).over(w)),
+ state_comparison_fn(Fn.col(mc), Fn.lag(Fn.col(mc), 1).over(w)),
)
temp_metric_compare_cols.append(temp_metric_compare_col)
# Remove first record which will have no state change
# and produces `null` for all state comparisons
- data = data.filter(f.col("previous_ts").isNotNull())
+ data = data.filter(Fn.col("previous_ts").isNotNull())
# Each state comparison should return True if state remained constant
data = data.withColumn(
- "state_change", f.array_contains(f.array(*temp_metric_compare_cols), False)
+ "state_change", Fn.array_contains(Fn.array(*temp_metric_compare_cols), False)
)
# Count the distinct state changes to get the unique intervals
data = data.withColumn(
"state_incrementer",
- f.sum(f.col("state_change").cast("int")).over(w),
- ).filter(~f.col("state_change"))
+ Fn.sum(Fn.col("state_change").cast("int")).over(w),
+ ).filter(~Fn.col("state_change"))
# Find the start and end timestamp of the interval
result = (
data.groupBy(*self.series_ids, "state_incrementer")
.agg(
- f.min("previous_ts").alias("start_ts"),
- f.max(self.ts_col).alias("end_ts"),
+ Fn.min("previous_ts").alias("start_ts"),
+ Fn.max(self.ts_col).alias("end_ts"),
)
.drop("state_incrementer")
)
diff --git a/python/tempo/tsschema.py b/python/tempo/tsschema.py
index 4fa1f091..aabb3617 100644
--- a/python/tempo/tsschema.py
+++ b/python/tempo/tsschema.py
@@ -1,11 +1,12 @@
from abc import ABC, abstractmethod
-from typing import Union, Collection, List
+from typing import Any, Union, Collection, List
-from pyspark.sql import Column
import pyspark.sql.functions as Fn
+from pyspark.sql import Column
from pyspark.sql.types import *
from pyspark.sql.types import NumericType
+
#
# Timeseries Index Classes
#
@@ -15,6 +16,25 @@ class TSIndex(ABC):
Abstract base class for all Timeseries Index types
"""
+ def __eq__(self, o: object) -> bool:
+        # must be a TSIndex
+ if not isinstance(o, TSIndex):
+ return False
+ return self.indexAttributes == o.indexAttributes
+
+ def __repr__(self) -> str:
+ return self.__str__()
+
+ def __str__(self) -> str:
+ return f"""{self.__class__.__name__}({self.indexAttributes})"""
+
+ @property
+ @abstractmethod
+    def indexAttributes(self) -> dict[str, Any]:
+ """
+ :return: key attributes of this index
+ """
+
@property
@abstractmethod
def name(self) -> str:
@@ -29,6 +49,16 @@ def ts_col(self) -> str:
:return: the name of the primary timeseries column (may or may not be the same as the name)
"""
+ @abstractmethod
+ def renamed(self, new_name: str) -> "TSIndex":
+ """
+ Renames the index
+
+ :param new_name: new name of the index
+
+ :return: a copy of this :class:`TSIndex` object with the new name
+ """
+
def _reverseOrNot(self, expr: Union[Column, List[Column]], reverse: bool) -> Union[Column, List[Column]]:
if not reverse:
return expr # just return the expression as-is if we're not reversing
@@ -72,6 +102,11 @@ def __init__(self, ts_col: StructField) -> None:
self.__name = ts_col.name
self.dataType = ts_col.dataType
+ @property
+ def indexAttributes(self) -> dict[str, Any]:
+ return { 'name': self.name,
+ 'dataType': self.dataType }
+
@property
def name(self):
return self.__name
@@ -80,6 +115,10 @@ def name(self):
def ts_col(self) -> str:
return self.name
+ def renamed(self, new_name: str) -> "TSIndex":
+ self.__name = new_name
+ return self
+
def orderByExpr(self, reverse: bool = False) -> Column:
expr = Fn.col(self.name)
return self._reverseOrNot(expr, reverse)
@@ -137,7 +176,7 @@ def __init__(self, ts_col: StructField) -> None:
def rangeOrderByExpr(self, reverse: bool = False) -> Column:
# convert date to number of days since the epoch
expr = Fn.datediff(Fn.col(self.name), Fn.lit("1970-01-01").cast("date"))
- self._reverseOrNot(expr,reverse)
+        return self._reverseOrNot(expr, reverse)
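The missing `return` meant callers previously got None back; the fixed expression yields a numeric days-since-epoch ordering that range-between windows can consume. A sketch of the equivalent Spark expression (DataFrame and column names hypothetical):

    import pyspark.sql.functions as Fn
    from pyspark.sql import Window

    days_since_epoch = Fn.datediff(Fn.col("date"), Fn.lit("1970-01-01").cast("date"))
    w = Window.partitionBy("symbol").orderBy(days_since_epoch).rangeBetween(-7, 0)
    trailing_week_avg = df.withColumn("avg_7d", Fn.avg("trade_pr").over(w))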
#
# Compound TS Index Types
@@ -152,14 +191,20 @@ class CompositeTSIndex(TSIndex, ABC):
def __init__(self, composite_ts_idx: StructField, primary_ts_col: str) -> None:
if not isinstance(composite_ts_idx.dataType, StructType):
raise TypeError(f"CompoundTSIndex must be of type StructType, but given compound_ts_idx {composite_ts_idx.name} has type {composite_ts_idx.dataType}")
- self.ts_idx: str = composite_ts_idx.name
+ self.__name: str = composite_ts_idx.name
self.struct: StructType = composite_ts_idx.dataType
# construct a simple TS index object for the primary column
self.primary_ts_idx: SimpleTSIndex = SimpleTSIndex.fromTSCol(self.struct[primary_ts_col])
+ @property
+ def indexAttributes(self) -> dict[str, Any]:
+ return { 'name': self.name,
+ 'struct': self.struct,
+ 'primary_ts_col': self.primary_ts_idx }
+
@property
def name(self) -> str:
- return self.ts_idx
+ return self.__name
@property
def ts_col(self) -> str:
@@ -169,6 +214,10 @@ def ts_col(self) -> str:
def primary_ts_col(self) -> str:
return self.component(self.primary_ts_idx.name)
+ def renamed(self, new_name: str) -> "TSIndex":
+ self.__name = new_name
+ return self
+
def component(self, component_name):
"""
Returns the full path to a component column that is within the composite index
@@ -196,6 +245,12 @@ def __init__(self, composite_ts_idx: StructField, primary_ts_col: str, sub_seq_c
# construct a simple index for the sub-sequence column
self.sub_sequence_idx = NumericIndex(self.struct[sub_seq_col])
+ @property
+ def indexAttributes(self) -> dict[str, Any]:
+ attrs = super().indexAttributes
+ attrs['sub_sequence_idx'] = self.sub_sequence_idx
+ return attrs
+
@property
def sub_seq_col(self) -> str:
return self.component(self.sub_sequence_idx.name)
@@ -219,6 +274,12 @@ def __init__(self, composite_ts_idx: StructField, src_str_col: str, parsed_col:
raise TypeError(f"Source string column must be of StringType, but given column {src_str_field.name} is of type {src_str_field.dataType}")
self.__src_str_col = src_str_col
+ @property
+ def indexAttributes(self) -> dict[str, Any]:
+ attrs = super().indexAttributes
+ attrs['src_str_col'] = self.src_str_col
+ return attrs
+
@property
def src_str_col(self):
return self.component(self.__src_str_col)
@@ -253,7 +314,7 @@ def __init__(self, composite_ts_idx: StructField, src_str_col: str, parsed_col:
def rangeOrderByExpr(self, reverse: bool = False) -> Column:
# convert date to number of days since the epoch
expr = Fn.datediff(Fn.col(self.primary_ts_col), Fn.lit("1970-01-01").cast("date"))
- self._reverseOrNot(expr,reverse)
+        return self._reverseOrNot(expr, reverse)
#
# Timseries Schema
@@ -264,18 +325,6 @@ class TSSchema:
Schema type for a :class:`TSDF` class.
"""
- # Valid types for metric columns
- __metric_types = (
- BooleanType(),
- ByteType(),
- ShortType(),
- IntegerType(),
- LongType(),
- DecimalType(),
- FloatType(),
- DoubleType(),
- )
-
def __init__(
self,
ts_idx: TSIndex,
@@ -285,7 +334,27 @@ def __init__(
if series_ids:
self.series_ids = list(series_ids)
else:
- self.series_ids = None
+ self.series_ids = []
+
+ def __eq__(self, o: object) -> bool:
+ # must be of TSSchema type
+ if not isinstance(o, TSSchema):
+ return False
+ # must have same TSIndex
+ if self.ts_idx != o.ts_idx:
+ return False
+ # must have the same series IDs
+ if self.series_ids != o.series_ids:
+ return False
+ return True
+
+ def __repr__(self) -> str:
+ return self.__str__()
+
+ def __str__(self) -> str:
+ return f"""TSSchema({id(self)})
+ TSIndex: {self.ts_idx}
+ Series IDs: {self.series_ids}"""
@classmethod
def fromDFSchema(
@@ -296,28 +365,32 @@ def fromDFSchema(
return cls(ts_idx, series_ids)
@property
- def structural_columns(self) -> set[str]:
+ def structural_columns(self) -> list[str]:
"""
Structural columns are those that define the structure of the :class:`TSDF`. This includes the timeseries column,
a timeseries index (if different), any subsequence column (if present), and the series ID columns.
:return: a set of column names corresponding the structural columns of a :class:`TSDF`
"""
- struct_cols = {self.ts_idx.name}.union(self.series_ids)
- struct_cols.discard(None)
- return struct_cols
+ return list({self.ts_idx.name}.union(self.series_ids))
def validate(self, df_schema: StructType) -> None:
pass
- def find_observational_columns(self, df_schema: StructType) -> set[str]:
- return set(df_schema.fieldNames()) - self.structural_columns
+ def find_observational_columns(self, df_schema: StructType) -> list[str]:
+ return list(set(df_schema.fieldNames()) - set(self.structural_columns))
+
+ @classmethod
+ def is_metric_col(cls, col: StructField) -> bool:
+        return isinstance(col.dataType, (NumericType, BooleanType))
def find_metric_columns(self, df_schema: StructType) -> list[str]:
return [
col.name
for col in df_schema.fields
- if (col.dataType in self.__metric_types)
- and
+ if self.is_metric_col(col)
+ and
(col.name in self.find_observational_columns(df_schema))
]
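A sketch of how the reworked TSSchema classifies columns (schema and column names are hypothetical; fromDFSchema is called positionally, as elsewhere in the patch):

    from pyspark.sql.types import (DoubleType, StringType, StructField,
                                   StructType, TimestampType)
    from tempo.tsschema import TSSchema

    df_schema = StructType([
        StructField("symbol", StringType()),
        StructField("event_ts", TimestampType()),
        StructField("trade_pr", DoubleType()),
    ])

    schema = TSSchema.fromDFSchema(df_schema, "event_ts", ["symbol"])
    print(schema.structural_columns)                     # ['event_ts', 'symbol'] (order not guaranteed)
    print(schema.find_observational_columns(df_schema))  # ['trade_pr']
    print(schema.find_metric_columns(df_schema))         # ['trade_pr']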
diff --git a/python/tempo/utils.py b/python/tempo/utils.py
index 0e22030d..5becbb43 100644
--- a/python/tempo/utils.py
+++ b/python/tempo/utils.py
@@ -1,7 +1,8 @@
-from typing import List
import logging
import os
import warnings
+from typing import List
+
from IPython import get_ipython
from IPython.core.display import HTML
from IPython.display import display as ipydisplay
@@ -141,12 +142,7 @@ def display_unavailable(df):
def get_display_df(tsdf, k):
- # let's show the n most recent records per series, in order:
- orderCols = tsdf.partitionCols.copy()
- orderCols.append(tsdf.ts_col)
- if tsdf.sequence_col:
- orderCols.append(tsdf.sequence_col)
- return tsdf.latest(k).df.orderBy(orderCols)
+ return tsdf.latest(k).withNaturalOrdering().df
ENV_CAN_RENDER_HTML = _is_capable_of_html_rendering()
diff --git a/python/tests/base.py b/python/tests/base.py
index ed8548ec..9a4a81d4 100644
--- a/python/tests/base.py
+++ b/python/tests/base.py
@@ -1,17 +1,16 @@
-import re
import os
import unittest
import warnings
from typing import Union
import jsonref
-
import pyspark.sql.functions as F
-from pyspark.sql import SparkSession
-from tempo.tsdf import TSDF
from chispa import assert_df_equality
+from pyspark.sql import SparkSession
from pyspark.sql.dataframe import DataFrame
+from tempo.tsdf import TSDF
+
class SparkTest(unittest.TestCase):
#
@@ -144,7 +143,14 @@ def buildTestDF(self, schema, data, ts_cols=["event_ts"]):
# convert timstamp fields to timestamp type
for tsc in ts_cols:
- df = df.withColumn(tsc, F.to_timestamp(F.col(tsc)))
+ # check if the column is nested in a struct or not
+ if '.' in tsc:
+ # we're changing a field nested in a struct
+ (struct, field) = tsc.split('.')
+ df = df.withColumn(struct, F.col(struct).withField(field, F.to_timestamp(tsc)))
+ else:
+ # standard column
+ df = df.withColumn(tsc, F.to_timestamp(F.col(tsc)))
return df
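The restored nested-struct branch above relies on Column.withField; a standalone sketch of the same conversion (struct and field names hypothetical):

    import pyspark.sql.functions as F

    # convert a timestamp string nested inside a struct column "ts_idx.event_ts"
    df = df.withColumn(
        "ts_idx",
        F.col("ts_idx").withField("event_ts", F.to_timestamp("ts_idx.event_ts")),
    )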
#
diff --git a/python/tests/tsdf_tests.py b/python/tests/tsdf_tests.py
index 3770062e..75e7bac9 100644
--- a/python/tests/tsdf_tests.py
+++ b/python/tests/tsdf_tests.py
@@ -19,8 +19,7 @@ def test_TSDF_init(self):
self.assertIsInstance(tsdf_init.df, DataFrame)
self.assertEqual(tsdf_init.ts_col, "event_ts")
- self.assertEqual(tsdf_init.partitionCols, ["symbol"])
- self.assertEqual(tsdf_init.sequence_col, "")
+ self.assertEqual(tsdf_init.series_ids, ["symbol"])
def test_describe(self):
"""AS-OF Join without a time-partition test"""
@@ -53,12 +52,13 @@ def test_describe(self):
== "2020-09-01 00:19:12"
)
- def test__getBytesFromPlan(self):
- init_tsdf = self.get_data_as_tsdf("init")
-
- _bytes = init_tsdf._TSDF__getBytesFromPlan(init_tsdf.df, self.spark)
-
- self.assertEqual(_bytes, 6.2)
+ # TODO - will be moved to new test suite for asOfJoin
+ # def test__getBytesFromPlan(self):
+ # init_tsdf = self.get_data_as_tsdf("init")
+ #
+ # _bytes = init_tsdf._TSDF__getBytesFromPlan(init_tsdf.df, self.spark)
+ #
+ # self.assertEqual(_bytes, 6.2)
@staticmethod
@mock.patch.dict(os.environ, {"TZ": "UTC"})
@@ -72,180 +72,184 @@ def __tsdf_with_double_tscol(tsdf: TSDF) -> TSDF:
)
return TSDF(with_double_tscol_df, ts_col=tsdf.ts_col, series_ids=tsdf.series_ids)
- def test__add_double_ts(self):
- init_tsdf = self.get_data_as_tsdf("init")
- df = init_tsdf._TSDF__add_double_ts()
-
- schema_string = df.schema.simpleString()
-
- self.assertIn("double_ts:double", schema_string)
-
- def test__validate_ts_string_valid(self):
- valid_timestamp_string = "2020-09-01 00:02:10"
-
- self.assertIsNone(TSDF._TSDF__validate_ts_string(valid_timestamp_string))
-
- def test__validate_ts_string_alt_format_valid(self):
- valid_timestamp_string = "2020-09-01T00:02:10"
-
- self.assertIsNone(TSDF._TSDF__validate_ts_string(valid_timestamp_string))
-
- def test__validate_ts_string_with_microseconds_valid(self):
- valid_timestamp_string = "2020-09-01 00:02:10.00000000"
-
- self.assertIsNone(TSDF._TSDF__validate_ts_string(valid_timestamp_string))
-
- def test__validate_ts_string_alt_format_with_microseconds_valid(self):
- valid_timestamp_string = "2020-09-01T00:02:10.00000000"
-
- self.assertIsNone(TSDF._TSDF__validate_ts_string(valid_timestamp_string))
-
- def test__validate_ts_string_invalid(self):
- invalid_timestamp_string = "this will not work"
-
- self.assertRaises(
- ValueError, TSDF._TSDF__validate_ts_string, invalid_timestamp_string
- )
-
- def test__validated_column_not_string(self):
- init_df = self.get_data_as_tsdf("init").df
-
- self.assertRaises(TypeError, TSDF._TSDF__validated_column, init_df, 0)
-
- def test__validated_column_not_found(self):
- init_df = self.get_data_as_tsdf("init").df
-
- self.assertRaises(
- ValueError,
- TSDF._TSDF__validated_column,
- init_df,
- "does not exist",
- )
-
- def test__validated_column(self):
- init_df = self.get_data_as_tsdf("init").df
-
- self.assertEqual(
- TSDF._TSDF__validated_column(init_df, "symbol"),
- "symbol",
- )
-
- def test__validated_columns_string(self):
- init_tsdf = self.get_data_as_tsdf("init")
-
- self.assertEqual(
- init_tsdf._TSDF__validated_columns(init_tsdf.df, "symbol"),
- ["symbol"],
- )
-
- def test__validated_columns_none(self):
- init_tsdf = self.get_data_as_tsdf("init")
-
- self.assertEqual(
- init_tsdf._TSDF__validated_columns(init_tsdf.df, None),
- [],
- )
-
- def test__validated_columns_tuple(self):
- init_tsdf = self.get_data_as_tsdf("init")
-
- self.assertRaises(
- TypeError,
- init_tsdf._TSDF__validated_columns,
- init_tsdf.df,
- ("symbol",),
- )
-
- def test__validated_columns_list_multiple_elems(self):
- init_tsdf = self.get_data_as_tsdf("init")
-
- self.assertEqual(
- init_tsdf._TSDF__validated_columns(
- init_tsdf.df,
- ["symbol", "event_ts", "trade_pr"],
- ),
- ["symbol", "event_ts", "trade_pr"],
- )
-
- def test__checkPartitionCols(self):
- init_tsdf = self.get_data_as_tsdf("init")
- right_tsdf = self.get_data_as_tsdf("right_tsdf")
-
- self.assertRaises(ValueError, init_tsdf._TSDF__checkPartitionCols, right_tsdf)
-
- def test__validateTsColMatch(self):
- init_tsdf = self.get_data_as_tsdf("init")
- right_tsdf = self.get_data_as_tsdf("right_tsdf")
-
- self.assertRaises(ValueError, init_tsdf._TSDF__validateTsColMatch, right_tsdf)
-
- def test__addPrefixToColumns_non_empty_string(self):
- init_tsdf = self.get_data_as_tsdf("init")
-
- df = init_tsdf._TSDF__addPrefixToColumns(["event_ts"], "prefix").df
-
- schema_string = df.schema.simpleString()
-
- self.assertIn("prefix_event_ts", schema_string)
-
- def test__addPrefixToColumns_empty_string(self):
- init_tsdf = self.get_data_as_tsdf("init")
-
- df = init_tsdf._TSDF__addPrefixToColumns(["event_ts"], "").df
-
- schema_string = df.schema.simpleString()
-
- # comma included (,event_ts) to ensure we don't match if there is a prefix added
- self.assertIn(",event_ts", schema_string)
-
- def test__addColumnsFromOtherDF(self):
- init_tsdf = self.get_data_as_tsdf("init")
-
- df = init_tsdf._TSDF__addColumnsFromOtherDF(["another_col"]).df
-
- schema_string = df.schema.simpleString()
-
- self.assertIn("another_col", schema_string)
-
- def test__combineTSDF(self):
- init1_tsdf = self.get_data_as_tsdf("init")
- init2_tsdf = self.get_data_as_tsdf("init")
-
- union_tsdf = init1_tsdf._TSDF__combineTSDF(init2_tsdf, "combined_ts_col")
- df = union_tsdf.df
-
- schema_string = df.schema.simpleString()
-
- self.assertEqual(init1_tsdf.df.count() + init2_tsdf.df.count(), df.count())
- self.assertIn("combined_ts_col", schema_string)
-
- def test__getLastRightRow(self):
- # TODO: several errors and hard-coded columns that throw AnalysisException
- pass
-
- def test__getTimePartitions(self):
- init_tsdf = self.get_data_as_tsdf("init")
- expected_tsdf = self.get_data_as_tsdf("expected")
-
- actual_tsdf = init_tsdf._TSDF__getTimePartitions(10)
-
- self.assertDataFrameEquality(
- actual_tsdf,
- expected_tsdf,
- from_tsdf=True,
- )
-
- def test__getTimePartitions_with_fraction(self):
- init_tsdf = self.get_data_as_tsdf("init")
- expected_tsdf = self.get_data_as_tsdf("expected")
-
- actual_tsdf = init_tsdf._TSDF__getTimePartitions(10, 0.25)
-
- self.assertDataFrameEquality(
- actual_tsdf,
- expected_tsdf,
- from_tsdf=True,
- )
+ # TODO - replace this with test code for TSDF.fromTimestampString with nano-second precision
+ # def test__add_double_ts(self):
+ # init_tsdf = self.get_data_as_tsdf("init")
+ # df = init_tsdf._TSDF__add_double_ts()
+ #
+ # schema_string = df.schema.simpleString()
+ #
+ # self.assertIn("double_ts:double", schema_string)
+
+ # TODO - replace with tests for TSDF.fromTimestampString
+ # def test__validate_ts_string_valid(self):
+ # valid_timestamp_string = "2020-09-01 00:02:10"
+ #
+ # self.assertIsNone(TSDF._TSDF__validate_ts_string(valid_timestamp_string))
+ #
+ # def test__validate_ts_string_alt_format_valid(self):
+ # valid_timestamp_string = "2020-09-01T00:02:10"
+ #
+ # self.assertIsNone(TSDF._TSDF__validate_ts_string(valid_timestamp_string))
+ #
+ # def test__validate_ts_string_with_microseconds_valid(self):
+ # valid_timestamp_string = "2020-09-01 00:02:10.00000000"
+ #
+ # self.assertIsNone(TSDF._TSDF__validate_ts_string(valid_timestamp_string))
+ #
+ # def test__validate_ts_string_alt_format_with_microseconds_valid(self):
+ # valid_timestamp_string = "2020-09-01T00:02:10.00000000"
+ #
+ # self.assertIsNone(TSDF._TSDF__validate_ts_string(valid_timestamp_string))
+ #
+ # def test__validate_ts_string_invalid(self):
+ # invalid_timestamp_string = "this will not work"
+ #
+ # self.assertRaises(
+ # ValueError, TSDF._TSDF__validate_ts_string, invalid_timestamp_string
+ # )
+
+ # TODO - replace with tests of TSSchema validation
+ # def test__validated_column_not_string(self):
+ # init_df = self.get_data_as_tsdf("init").df
+ #
+ # self.assertRaises(TypeError, TSDF._TSDF__validated_column, init_df, 0)
+ #
+ # def test__validated_column_not_found(self):
+ # init_df = self.get_data_as_tsdf("init").df
+ #
+ # self.assertRaises(
+ # ValueError,
+ # TSDF._TSDF__validated_column,
+ # init_df,
+ # "does not exist",
+ # )
+ #
+ # def test__validated_column(self):
+ # init_df = self.get_data_as_tsdf("init").df
+ #
+ # self.assertEqual(
+ # TSDF._TSDF__validated_column(init_df, "symbol"),
+ # "symbol",
+ # )
+ #
+ # def test__validated_columns_string(self):
+ # init_tsdf = self.get_data_as_tsdf("init")
+ #
+ # self.assertEqual(
+ # init_tsdf._TSDF__validated_columns(init_tsdf.df, "symbol"),
+ # ["symbol"],
+ # )
+ #
+ # def test__validated_columns_none(self):
+ # init_tsdf = self.get_data_as_tsdf("init")
+ #
+ # self.assertEqual(
+ # init_tsdf._TSDF__validated_columns(init_tsdf.df, None),
+ # [],
+ # )
+ #
+ # def test__validated_columns_tuple(self):
+ # init_tsdf = self.get_data_as_tsdf("init")
+ #
+ # self.assertRaises(
+ # TypeError,
+ # init_tsdf._TSDF__validated_columns,
+ # init_tsdf.df,
+ # ("symbol",),
+ # )
+ #
+ # def test__validated_columns_list_multiple_elems(self):
+ # init_tsdf = self.get_data_as_tsdf("init")
+ #
+ # self.assertEqual(
+ # init_tsdf._TSDF__validated_columns(
+ # init_tsdf.df,
+ # ["symbol", "event_ts", "trade_pr"],
+ # ),
+ # ["symbol", "event_ts", "trade_pr"],
+ # )
+
+ # TODO - replace with test code for refactored asOfJoin helpers
+ # def test__checkPartitionCols(self):
+ # init_tsdf = self.get_data_as_tsdf("init")
+ # right_tsdf = self.get_data_as_tsdf("right_tsdf")
+ #
+ # self.assertRaises(ValueError, init_tsdf._TSDF__checkPartitionCols, right_tsdf)
+ #
+ # def test__validateTsColMatch(self):
+ # init_tsdf = self.get_data_as_tsdf("init")
+ # right_tsdf = self.get_data_as_tsdf("right_tsdf")
+ #
+ # self.assertRaises(ValueError, init_tsdf._TSDF__validateTsColMatch, right_tsdf)
+ #
+ # def test__addPrefixToColumns_non_empty_string(self):
+ # init_tsdf = self.get_data_as_tsdf("init")
+ #
+ # df = init_tsdf._TSDF__addPrefixToColumns(["event_ts"], "prefix").df
+ #
+ # schema_string = df.schema.simpleString()
+ #
+ # self.assertIn("prefix_event_ts", schema_string)
+ #
+ # def test__addPrefixToColumns_empty_string(self):
+ # init_tsdf = self.get_data_as_tsdf("init")
+ #
+ # df = init_tsdf._TSDF__addPrefixToColumns(["event_ts"], "").df
+ #
+ # schema_string = df.schema.simpleString()
+ #
+ # # comma included (,event_ts) to ensure we don't match if there is a prefix added
+ # self.assertIn(",event_ts", schema_string)
+ #
+ # def test__addColumnsFromOtherDF(self):
+ # init_tsdf = self.get_data_as_tsdf("init")
+ #
+ # df = init_tsdf._TSDF__addColumnsFromOtherDF(["another_col"]).df
+ #
+ # schema_string = df.schema.simpleString()
+ #
+ # self.assertIn("another_col", schema_string)
+ #
+ # def test__combineTSDF(self):
+ # init1_tsdf = self.get_data_as_tsdf("init")
+ # init2_tsdf = self.get_data_as_tsdf("init")
+ #
+ # union_tsdf = init1_tsdf._TSDF__combineTSDF(init2_tsdf, "combined_ts_col")
+ # df = union_tsdf.df
+ #
+ # schema_string = df.schema.simpleString()
+ #
+ # self.assertEqual(init1_tsdf.df.count() + init2_tsdf.df.count(), df.count())
+ # self.assertIn("combined_ts_col", schema_string)
+ #
+ # def test__getLastRightRow(self):
+ # # TODO: several errors and hard-coded columns that throw AnalysisException
+ # pass
+ #
+ # def test__getTimePartitions(self):
+ # init_tsdf = self.get_data_as_tsdf("init")
+ # expected_tsdf = self.get_data_as_tsdf("expected")
+ #
+ # actual_tsdf = init_tsdf._TSDF__getTimePartitions(10)
+ #
+ # self.assertDataFrameEquality(
+ # actual_tsdf,
+ # expected_tsdf,
+ # from_tsdf=True,
+ # )
+ #
+ # def test__getTimePartitions_with_fraction(self):
+ # init_tsdf = self.get_data_as_tsdf("init")
+ # expected_tsdf = self.get_data_as_tsdf("expected")
+ #
+ # actual_tsdf = init_tsdf._TSDF__getTimePartitions(10, 0.25)
+ #
+ # self.assertDataFrameEquality(
+ # actual_tsdf,
+ # expected_tsdf,
+ # from_tsdf=True,
+ # )
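+
+ # One possible direction for the replacement tests: exercise the public
+ # asofJoin API rather than the private helpers. Fixture names and the
+ # expected behavior below are illustrative assumptions:
+ #
+ # def test_asof_join_keeps_left_rows(self):
+ #     left_tsdf = self.get_data_as_tsdf("init")
+ #     right_tsdf = self.get_data_as_tsdf("right_tsdf")
+ #
+ #     joined = left_tsdf.asofJoin(right_tsdf)
+ #
+ #     # an as-of join behaves like a left join: same series ids, one row per left row
+ #     self.assertEqual(joined.series_ids, left_tsdf.series_ids)
+ #     self.assertEqual(joined.df.count(), left_tsdf.df.count())
+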
def test_select_empty(self):
# TODO: Can we narrow down to types of Exception?
@@ -811,13 +815,13 @@ def test__rowsBetweenWindow(self):
self.assertIsInstance(init_tsdf._TSDF__rowsBetweenWindow(1, 1), WindowSpec)
- def test_withPartitionCols(self):
- init_tsdf = self.get_data_as_tsdf("init")
-
- actual_tsdf = init_tsdf.withPartitionCols(["symbol"])
-
- self.assertEqual(init_tsdf.partitionCols, [])
- self.assertEqual(actual_tsdf.partitionCols, ["symbol"])
+ # def test_withPartitionCols(self):
+ # init_tsdf = self.get_data_as_tsdf("init")
+ #
+ # actual_tsdf = init_tsdf.withPartitionCols(["symbol"])
+ #
+ # self.assertEqual(init_tsdf.partitionCols, [])
+ # self.assertEqual(actual_tsdf.partitionCols, ["symbol"])
class FourierTransformTest(SparkTest):
diff --git a/python/tests/utils_tests.py b/python/tests/utils_tests.py
index 24c16d69..3f63b669 100644
--- a/python/tests/utils_tests.py
+++ b/python/tests/utils_tests.py
@@ -121,21 +121,22 @@ def test_display_unavailable(self):
],
)
- def test_get_display_df(self):
- init_tsdf = self.get_data_as_tsdf("init")
- expected_df = self.get_data_as_sdf("expected")
-
- actual_df = get_display_df(init_tsdf, 2)
-
- self.assertDataFrameEquality(actual_df, expected_df)
-
- def test_get_display_df_sequence_col(self):
- init_tsdf = self.get_data_as_tsdf("init")
- expected_df = self.get_data_as_sdf("expected")
-
- actual_df = get_display_df(init_tsdf, 2)
-
- self.assertDataFrameEquality(actual_df, expected_df)
+ # TODO - replace with tests of natural ordering & show (see the sketch at the end of this block)
+ # def test_get_display_df(self):
+ # init_tsdf = self.get_data_as_tsdf("init")
+ # expected_df = self.get_data_as_sdf("expected")
+ #
+ # actual_df = get_display_df(init_tsdf, 2)
+ #
+ # self.assertDataFrameEquality(actual_df, expected_df)
+ #
+ # def test_get_display_df_sequence_col(self):
+ # init_tsdf = self.get_data_as_tsdf("init")
+ # expected_df = self.get_data_as_sdf("expected")
+ #
+ # actual_df = get_display_df(init_tsdf, 2)
+ #
+ # self.assertDataFrameEquality(actual_df, expected_df)
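+
+ # A sketch of what the replacement tests might check, assuming get_display_df
+ # keeps its (tsdf, k) signature and TSDF.withNaturalOrdering only reorders
+ # rows; the "init" fixture name is an illustrative assumption:
+ #
+ # def test_with_natural_ordering_preserves_rows(self):
+ #     init_tsdf = self.get_data_as_tsdf("init")
+ #
+ #     ordered_df = init_tsdf.withNaturalOrdering().df
+ #
+ #     # reordering should neither add nor drop rows
+ #     self.assertEqual(ordered_df.count(), init_tsdf.df.count())
+ #
+ # def test_get_display_df_limits_rows(self):
+ #     init_tsdf = self.get_data_as_tsdf("init")
+ #
+ #     display_df = get_display_df(init_tsdf, 2)
+ #
+ #     # at most k rows per series should come back
+ #     num_series = init_tsdf.df.select(*init_tsdf.series_ids).distinct().count()
+ #     self.assertLessEqual(display_df.count(), 2 * num_series)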
# MAIN
From 1459ff52bc6c74099f497c68fbb9cb9832f812c5 Mon Sep 17 00:00:00 2001
From: Tristan Nixon
Date: Mon, 24 Oct 2022 18:01:51 -0700
Subject: [PATCH 09/11] black code formatting
---
python/tempo/resample.py | 18 ++----
python/tempo/tsdf.py | 102 +++++++++++++++++++++--------
python/tempo/tsschema.py | 113 +++++++++++++++++++++------------
python/tests/base.py | 12 ++--
python/tests/interpol_tests.py | 7 +-
python/tests/tsdf_tests.py | 4 +-
6 files changed, 166 insertions(+), 90 deletions(-)
diff --git a/python/tempo/resample.py b/python/tempo/resample.py
index 9255b991..baae03a2 100644
--- a/python/tempo/resample.py
+++ b/python/tempo/resample.py
@@ -104,9 +104,7 @@ def aggregate(
exprs = {x: "avg" for x in metricCols}
res = df.groupBy(groupingCols).agg(exprs)
agg_metric_cls = list(
- set(res.columns).difference(
- set(tsdf.series_ids + [tsdf.ts_col, "agg_key"])
- )
+ set(res.columns).difference(set(tsdf.series_ids + [tsdf.ts_col, "agg_key"]))
)
new_cols = [
f.col(c).alias("{}".format(prefix) + (c.split("avg(")[1]).replace(")", ""))
@@ -117,9 +115,7 @@ def aggregate(
exprs = {x: "min" for x in metricCols}
res = df.groupBy(groupingCols).agg(exprs)
agg_metric_cls = list(
- set(res.columns).difference(
- set(tsdf.series_ids + [tsdf.ts_col, "agg_key"])
- )
+ set(res.columns).difference(set(tsdf.series_ids + [tsdf.ts_col, "agg_key"]))
)
new_cols = [
f.col(c).alias("{}".format(prefix) + (c.split("min(")[1]).replace(")", ""))
@@ -130,9 +126,7 @@ def aggregate(
exprs = {x: "max" for x in metricCols}
res = df.groupBy(groupingCols).agg(exprs)
agg_metric_cls = list(
- set(res.columns).difference(
- set(tsdf.series_ids + [tsdf.ts_col, "agg_key"])
- )
+ set(res.columns).difference(set(tsdf.series_ids + [tsdf.ts_col, "agg_key"]))
)
new_cols = [
f.col(c).alias("{}".format(prefix) + (c.split("max(")[1]).replace(")", ""))
@@ -186,9 +180,9 @@ def aggregate(
metrics.append(col[0])
if fill:
- res = imputes.join(
- res, tsdf.series_ids + [tsdf.ts_col], "leftouter"
- ).na.fill(0, metrics)
+ res = imputes.join(res, tsdf.series_ids + [tsdf.ts_col], "leftouter").na.fill(
+ 0, metrics
+ )
return res
diff --git a/python/tempo/tsdf.py b/python/tempo/tsdf.py
index 2afa7120..8bb118f6 100644
--- a/python/tempo/tsdf.py
+++ b/python/tempo/tsdf.py
@@ -24,15 +24,17 @@
ENV_CAN_RENDER_HTML,
IS_DATABRICKS,
calculate_time_horizon,
- get_display_df
+ get_display_df,
)
logger = logging.getLogger(__name__)
+
class TSDFStructureChangeError(Exception):
"""
Error raised when a user attempts an operation that would alter the structure of a TSDF in a destructive manner.
"""
+
__MSG_TEMPLATE: str = """
The attempted operation ({op}) is not allowed because it would result in altering the structure of the TSDF.
If you really want to make this change, perform the operation on the underlying DataFrame, then re-create a new TSDF.
@@ -46,6 +48,7 @@ class IncompatibleTSError(Exception):
"""
Error raised when an operation is attempted between two incompatible TSDFs.
"""
+
__MSG_TEMPLATE: str = """
The attempted operation ({op}) cannot be performed because the given TSDFs have incompatible structure.
{d}"""
@@ -53,6 +56,7 @@ class IncompatibleTSError(Exception):
def __init__(self, operation: str, details: str = None) -> None:
super().__init__(self.__MSG_TEMPLATE.format(op=operation, d=details))
+
class TSDF:
"""
This object is the main wrapper over a Spark data frame which allows a user to parallelize time series computations on a Spark data frame by various dimensions. The two dimensions required are partition_cols (list of columns by which to summarize) and ts_col (timestamp column, which can be epoch or TimestampType).
@@ -64,7 +68,7 @@ def __init__(
ts_schema: TSSchema = None,
ts_col: str = None,
series_ids: Collection[str] = None,
- validate_schema=True
+ validate_schema=True,
) -> None:
self.df = df
# construct schema if we don't already have one
@@ -105,12 +109,16 @@ def __withStandardizedColOrder(self) -> TSDF:
:return: a :class:`TSDF` with the columns reordered into "standard order" (as described above)
"""
- std_ordered_cols = list(self.series_ids) + [self.ts_index.name] + list(self.observational_cols)
+ std_ordered_cols = (
+ list(self.series_ids) + [self.ts_index.name] + list(self.observational_cols)
+ )
return self.__withTransformedDF(self.df.select(std_ordered_cols))
@classmethod
- def __makeStructFromCols(cls, df: DataFrame, struct_col_name: str, cols_to_move: List[str]) -> DataFrame:
+ def __makeStructFromCols(
+ cls, df: DataFrame, struct_col_name: str, cols_to_move: List[str]
+ ) -> DataFrame:
"""
Transform a :class:`DataFrame` by moving certain columns into a struct
@@ -120,29 +128,50 @@ def __makeStructFromCols(cls, df: DataFrame, struct_col_name: str, cols_to_move:
:return: the transformed :class:`DataFrame`
"""
- return df.withColumn(struct_col_name, Fn.struct(cols_to_move)).drop(*cols_to_move)
+ return df.withColumn(struct_col_name, Fn.struct(cols_to_move)).drop(
+ *cols_to_move
+ )
# default column name for constructed timeseries index struct columns
__DEFAULT_TS_IDX_COL = "ts_idx"
@classmethod
- def fromSubsequenceCol(cls, df: DataFrame, ts_col: str, subsequence_col: str, series_ids: Collection[str] = None) -> "TSDF":
+ def fromSubsequenceCol(
+ cls,
+ df: DataFrame,
+ ts_col: str,
+ subsequence_col: str,
+ series_ids: Collection[str] = None,
+ ) -> "TSDF":
# construct a struct with the ts_col and subsequence_col
struct_col_name = cls.__DEFAULT_TS_IDX_COL
- with_subseq_struct_df = cls.__makeStructFromCols(df, struct_col_name, [ts_col, subsequence_col])
+ with_subseq_struct_df = cls.__makeStructFromCols(
+ df, struct_col_name, [ts_col, subsequence_col]
+ )
# construct an appropriate TSIndex
subseq_struct = with_subseq_struct_df.schema[struct_col_name]
subseq_idx = SubSequenceTSIndex(subseq_struct, ts_col, subsequence_col)
# construct & return the TSDF with appropriate schema
return TSDF(with_subseq_struct_df, ts_schema=TSSchema(subseq_idx, series_ids))
-
@classmethod
- def fromTimestampString(cls, df: DataFrame, ts_col: str, series_ids: Collection[str] = None, ts_fmt: str = "YYYY-MM-DDThh:mm:ss[.SSSSSS]") -> "TSDF":
+ def fromTimestampString(
+ cls,
+ df: DataFrame,
+ ts_col: str,
+ series_ids: Collection[str] = None,
+ ts_fmt: str = "YYYY-MM-DDThh:mm:ss[.SSSSSS]",
+ ) -> "TSDF":
pass
@classmethod
- def fromDateString(cls, df: DataFrame, ts_col: str, series_ids: Collection[str], date_fmt: str = "YYYY-MM-DD") -> "TSDF ":
+ def fromDateString(
+ cls,
+ df: DataFrame,
+ ts_col: str,
+ series_ids: Collection[str],
+ date_fmt: str = "YYYY-MM-DD",
+ ) -> "TSDF ":
pass
@property
@@ -424,7 +453,9 @@ def __getTimePartitions(self, tsPartitionVal, fraction=0.1):
df = partition_df.union(remainder_df).drop(
"partition_remainder", "ts_col_double"
)
- return TSDF(df, ts_col=self.ts_col, series_ids=self.series_ids + ["ts_partition"])
+ return TSDF(
+ df, ts_col=self.ts_col, series_ids=self.series_ids + ["ts_partition"]
+ )
#
# Slicing & Selection
@@ -452,7 +483,9 @@ def select(self, *cols):
if set(self.structural_cols).issubset(set(cols)):
return self.__withTransformedDF(self.df.select(*cols))
else:
- raise TSDFStructureChangeError("select that does not include all structural columns")
+ raise TSDFStructureChangeError(
+ "select that does not include all structural columns"
+ )
def __slice(self, op: str, target_ts):
"""
@@ -638,7 +671,7 @@ def show(self, n=20, k=5, truncate=True, vertical=False):
if not (IS_DATABRICKS) and ENV_CAN_RENDER_HTML:
# In Jupyter notebooks, for wide dataframes the below line will enable rendering the output in a scrollable format.
ipydisplay(HTML(""))
- get_display_df(self,k=k).show(n, truncate, vertical)
+ get_display_df(self, k=k).show(n, truncate, vertical)
def describe(self):
"""
@@ -670,12 +703,12 @@ def describe(self):
desc_stats = this_df.describe().union(missing_vals)
unique_ts = this_df.select(*self.series_ids).distinct().count()
- max_ts = this_df.select(Fn.max(Fn.col(self.ts_col)).alias("max_ts")).collect()[0][
+ max_ts = this_df.select(Fn.max(Fn.col(self.ts_col)).alias("max_ts")).collect()[
0
- ]
- min_ts = this_df.select(Fn.min(Fn.col(self.ts_col)).alias("max_ts")).collect()[0][
+ ][0]
+ min_ts = this_df.select(Fn.min(Fn.col(self.ts_col)).alias("max_ts")).collect()[
0
- ]
+ ][0]
gran = this_df.selectExpr(
"""min(case when {0} - cast({0} as integer) > 0 then '1-millis'
when {0} % 60 != 0 then '2-seconds'
@@ -818,8 +851,12 @@ def asofJoin(
)
new_left_ts_col = left_prefix + self.ts_col
- new_left_cols = [Fn.col(c).alias(left_prefix + c) for c in left_cols] + partition_cols
- new_right_cols = [Fn.col(c).alias(right_prefix + c) for c in right_cols] + partition_cols
+ new_left_cols = [
+ Fn.col(c).alias(left_prefix + c) for c in left_cols
+ ] + partition_cols
+ new_right_cols = [
+ Fn.col(c).alias(right_prefix + c) for c in right_cols
+ ] + partition_cols
quotes_df_w_lag = right_df.select(*new_right_cols).withColumn(
"lead_" + right_tsdf.ts_col,
Fn.lead(right_prefix + right_tsdf.ts_col).over(w),
@@ -942,16 +979,18 @@ def __rowsBetweenWindow(self, rows_from, rows_to, reverse=False):
return self.__baseWindow(reverse=reverse).rowsBetween(rows_from, rows_to)
def __rangeBetweenWindow(self, range_from, range_to, reverse=False):
- return ( self.__baseWindow(reverse=reverse)
- .orderBy(self.ts_index.rangeOrderByExpr(reverse=reverse))
- .rangeBetween(range_from, range_to ) )
+ return (
+ self.__baseWindow(reverse=reverse)
+ .orderBy(self.ts_index.rangeOrderByExpr(reverse=reverse))
+ .rangeBetween(range_from, range_to)
+ )
#
# Core Transformations
#
def withNaturalOrdering(self, reverse: bool = False) -> "TSDF":
- order_expr = [ Fn.col(c) for c in self.series_ids]
+ order_expr = [Fn.col(c) for c in self.series_ids]
ts_idx_expr = self.ts_index.orderByExpr(reverse)
if isinstance(ts_idx_expr, list):
order_expr.extend(ts_idx_expr)
@@ -968,7 +1007,9 @@ def withColumn(self, colName: str, col: Column) -> "TSDF":
:param col: a :class:`Column` expression for the new column definition
"""
if colName in self.structural_cols:
- raise TSDFStructureChangeError(f"withColumn on the structural column {colName}.")
+ raise TSDFStructureChangeError(
+ f"withColumn on the structural column {colName}."
+ )
new_df = self.df.withColumn(colName, col)
return self.__withTransformedDF(new_df)
@@ -993,7 +1034,7 @@ def withColumnRenamed(self, existing: str, new: str) -> "TSDF":
new_series_ids[new_series_ids.index(existing)] = new
# rename the column in the underlying DF
- new_df = self.df.withColumnRenamed(existing,new)
+ new_df = self.df.withColumnRenamed(existing, new)
# return new TSDF
new_schema = TSSchema(new_ts_index, new_series_ids)
@@ -1006,7 +1047,9 @@ def union(self, other: TSDF) -> TSDF:
def unionByName(self, other: TSDF, allowMissingColumns: bool = False) -> TSDF:
# union of the underlying DataFrames
- union_df = self.df.unionByName(other.df, allowMissingColumns=allowMissingColumns)
+ union_df = self.df.unionByName(
+ other.df, allowMissingColumns=allowMissingColumns
+ )
return self.__withTransformedDF(union_df)
#
@@ -1347,7 +1390,9 @@ def calc_bars(tsdf, freq, func=None, metricCols=None, fill=None):
)
bars = bars.select(sel_and_sort)
- return TSDF(bars, ts_col=resample_open.ts_col, series_ids=resample_open.series_ids)
+ return TSDF(
+ bars, ts_col=resample_open.ts_col, series_ids=resample_open.series_ids
+ )
def fourier_transform(self, timestep, valueCol):
"""
@@ -1508,7 +1553,8 @@ def state_comparison_fn(a, b):
# Each state comparison should return True if state remained constant
data = data.withColumn(
- "state_change", Fn.array_contains(Fn.array(*temp_metric_compare_cols), False)
+ "state_change",
+ Fn.array_contains(Fn.array(*temp_metric_compare_cols), False),
)
# Count the distinct state changes to get the unique intervals
diff --git a/python/tempo/tsschema.py b/python/tempo/tsschema.py
index aabb3617..51b019da 100644
--- a/python/tempo/tsschema.py
+++ b/python/tempo/tsschema.py
@@ -11,6 +11,7 @@
# Timeseries Index Classes
#
+
class TSIndex(ABC):
"""
Abstract base class for all Timeseries Index types
@@ -30,7 +31,7 @@ def __str__(self) -> str:
@property
@abstractmethod
- def indexAttributes(self) -> dict[str,Any]:
+ def indexAttributes(self) -> dict[str, Any]:
"""
:return: key attributes of this index
"""
@@ -59,13 +60,15 @@ def renamed(self, new_name: str) -> "TSIndex":
:return: a copy of this :class:`TSIndex` object with the new name
"""
- def _reverseOrNot(self, expr: Union[Column, List[Column]], reverse: bool) -> Union[Column, List[Column]]:
+ def _reverseOrNot(
+ self, expr: Union[Column, List[Column]], reverse: bool
+ ) -> Union[Column, List[Column]]:
if not reverse:
- return expr # just return the expression as-is if we're not reversing
+ return expr # just return the expression as-is if we're not reversing
elif type(expr) == Column:
- return expr.desc() # reverse a single-expression
+ return expr.desc() # reverse a single-expression
elif type(expr) == List[Column]:
- return [col.desc() for col in expr] # reverse all columns in the expression
+ return [col.desc() for col in expr] # reverse all columns in the expression
@abstractmethod
def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]:
@@ -88,10 +91,12 @@ def rangeOrderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]
"""
return self.orderByExpr(reverse=reverse)
+
#
# Simple TS Index types
#
+
class SimpleTSIndex(TSIndex, ABC):
"""
Abstract base class for simple Timeseries Index types
@@ -104,8 +109,7 @@ def __init__(self, ts_col: StructField) -> None:
@property
def indexAttributes(self) -> dict[str, Any]:
- return { 'name': self.name,
- 'dataType': self.dataType }
+ return {"name": self.name, "dataType": self.dataType}
@property
def name(self):
@@ -133,7 +137,9 @@ def fromTSCol(cls, ts_col: StructField) -> "SimpleTSIndex":
elif isinstance(ts_col.dataType, DateType):
return SimpleDateIndex(ts_col)
else:
- raise TypeError(f"A SimpleTSIndex must be a Numeric, Timestamp or Date type, but column {ts_col.name} is of type {ts_col.dataType}")
+ raise TypeError(
+ f"A SimpleTSIndex must be a Numeric, Timestamp or Date type, but column {ts_col.name} is of type {ts_col.dataType}"
+ )
class NumericIndex(SimpleTSIndex):
@@ -143,7 +149,9 @@ class NumericIndex(SimpleTSIndex):
def __init__(self, ts_col: StructField) -> None:
if not isinstance(ts_col.dataType, NumericType):
- raise TypeError(f"NumericIndex must be of a numeric type, but ts_col {ts_col.name} has type {ts_col.dataType}")
+ raise TypeError(
+ f"NumericIndex must be of a numeric type, but ts_col {ts_col.name} has type {ts_col.dataType}"
+ )
super().__init__(ts_col)
@@ -154,7 +162,9 @@ class SimpleTimestampIndex(SimpleTSIndex):
def __init__(self, ts_col: StructField) -> None:
if not isinstance(ts_col.dataType, TimestampType):
- raise TypeError(f"SimpleTimestampIndex must be of TimestampType, but given ts_col {ts_col.name} has type {ts_col.dataType}")
+ raise TypeError(
+ f"SimpleTimestampIndex must be of TimestampType, but given ts_col {ts_col.name} has type {ts_col.dataType}"
+ )
super().__init__(ts_col)
def rangeOrderByExpr(self, reverse: bool = False) -> Column:
@@ -170,18 +180,22 @@ class SimpleDateIndex(SimpleTSIndex):
def __init__(self, ts_col: StructField) -> None:
if not isinstance(ts_col.dataType, DateType):
- raise TypeError(f"DateIndex must be of DateType, but given ts_col {ts_col.name} has type {ts_col.dataType}")
+ raise TypeError(
+ f"DateIndex must be of DateType, but given ts_col {ts_col.name} has type {ts_col.dataType}"
+ )
super().__init__(ts_col)
def rangeOrderByExpr(self, reverse: bool = False) -> Column:
# convert date to number of days since the epoch
expr = Fn.datediff(Fn.col(self.name), Fn.lit("1970-01-01").cast("date"))
- return self._reverseOrNot(expr,reverse)
+ return self._reverseOrNot(expr, reverse)
+
#
# Compound TS Index Types
#
+
class CompositeTSIndex(TSIndex, ABC):
"""
Abstract base class for complex Timeseries Index classes
@@ -190,17 +204,23 @@ class CompositeTSIndex(TSIndex, ABC):
def __init__(self, composite_ts_idx: StructField, primary_ts_col: str) -> None:
if not isinstance(composite_ts_idx.dataType, StructType):
- raise TypeError(f"CompoundTSIndex must be of type StructType, but given compound_ts_idx {composite_ts_idx.name} has type {composite_ts_idx.dataType}")
+ raise TypeError(
+ f"CompoundTSIndex must be of type StructType, but given compound_ts_idx {composite_ts_idx.name} has type {composite_ts_idx.dataType}"
+ )
self.__name: str = composite_ts_idx.name
self.struct: StructType = composite_ts_idx.dataType
# construct a simple TS index object for the primary column
- self.primary_ts_idx: SimpleTSIndex = SimpleTSIndex.fromTSCol(self.struct[primary_ts_col])
+ self.primary_ts_idx: SimpleTSIndex = SimpleTSIndex.fromTSCol(
+ self.struct[primary_ts_col]
+ )
@property
def indexAttributes(self) -> dict[str, Any]:
- return { 'name': self.name,
- 'struct': self.struct,
- 'primary_ts_col': self.primary_ts_idx }
+ return {
+ "name": self.name,
+ "struct": self.struct,
+ "primary_ts_col": self.primary_ts_idx,
+ }
@property
def name(self) -> str:
@@ -231,7 +251,7 @@ def component(self, component_name):
def orderByExpr(self, reverse: bool = False) -> Column:
# default to using the primary column
expr = Fn.col(self.primary_ts_col)
- return self._reverseOrNot(expr,reverse)
+ return self._reverseOrNot(expr, reverse)
class SubSequenceTSIndex(CompositeTSIndex):
@@ -240,7 +260,9 @@ class SubSequenceTSIndex(CompositeTSIndex):
column that indicates the
"""
- def __init__(self, composite_ts_idx: StructField, primary_ts_col: str, sub_seq_col: str) -> None:
+ def __init__(
+ self, composite_ts_idx: StructField, primary_ts_col: str, sub_seq_col: str
+ ) -> None:
super().__init__(composite_ts_idx, primary_ts_col)
# construct a simple index for the sub-sequence column
self.sub_sequence_idx = NumericIndex(self.struct[sub_seq_col])
@@ -248,7 +270,7 @@ def __init__(self, composite_ts_idx: StructField, primary_ts_col: str, sub_seq_c
@property
def indexAttributes(self) -> dict[str, Any]:
attrs = super().indexAttributes
- attrs['sub_sequence_idx'] = self.sub_sequence_idx
+ attrs["sub_sequence_idx"] = self.sub_sequence_idx
return attrs
@property
@@ -257,7 +279,7 @@ def sub_seq_col(self) -> str:
def orderByExpr(self, reverse: bool = False) -> List[Column]:
# build a composite expression of the primary index followed by the sub-sequence index
- exprs = [ Fn.col(self.primary_ts_col), Fn.col(self.sub_seq_col) ]
+ exprs = [Fn.col(self.primary_ts_col), Fn.col(self.sub_seq_col)]
return self._reverseOrNot(exprs, reverse)
@@ -267,17 +289,21 @@ class ParsedTSIndex(CompositeTSIndex, ABC):
Retains the original string form as well as the parsed column.
"""
- def __init__(self, composite_ts_idx: StructField, src_str_col: str, parsed_col: str) -> None:
+ def __init__(
+ self, composite_ts_idx: StructField, src_str_col: str, parsed_col: str
+ ) -> None:
super().__init__(composite_ts_idx, primary_ts_col=parsed_col)
src_str_field = self.struct[src_str_col]
if not isinstance(src_str_field.dataType, StringType):
- raise TypeError(f"Source string column must be of StringType, but given column {src_str_field.name} is of type {src_str_field.dataType}")
+ raise TypeError(
+ f"Source string column must be of StringType, but given column {src_str_field.name} is of type {src_str_field.dataType}"
+ )
self.__src_str_col = src_str_col
@property
def indexAttributes(self) -> dict[str, Any]:
attrs = super().indexAttributes
- attrs['src_str_col'] = self.src_str_col
+ attrs["src_str_col"] = self.src_str_col
return attrs
@property
@@ -290,10 +316,14 @@ class ParsedTimestampIndex(ParsedTSIndex):
Timeseries index class for timestamps parsed from a string column
"""
- def __init__(self, composite_ts_idx: StructField, src_str_col: str, parsed_col: str) -> None:
+ def __init__(
+ self, composite_ts_idx: StructField, src_str_col: str, parsed_col: str
+ ) -> None:
super().__init__(composite_ts_idx, src_str_col, parsed_col)
if not isinstance(self.primary_ts_idx.dataType, TimestampType):
- raise TypeError(f"ParsedTimestampIndex must be of TimestampType, but given ts_col {self.primary_ts_idx.name} has type {self.primary_ts_idx.dataType}")
+ raise TypeError(
+ f"ParsedTimestampIndex must be of TimestampType, but given ts_col {self.primary_ts_idx.name} has type {self.primary_ts_idx.dataType}"
+ )
def rangeOrderByExpr(self, reverse: bool = False) -> Column:
# cast timestamp to double (fractional seconds since epoch)
@@ -306,30 +336,34 @@ class ParsedDateIndex(ParsedTSIndex):
Timeseries index class for dates parsed from a string column
"""
- def __init__(self, composite_ts_idx: StructField, src_str_col: str, parsed_col: str) -> None:
+ def __init__(
+ self, composite_ts_idx: StructField, src_str_col: str, parsed_col: str
+ ) -> None:
super().__init__(composite_ts_idx, src_str_col, parsed_col)
if not isinstance(self.primary_ts_idx.dataType, DateType):
- raise TypeError(f"ParsedDateIndex must be of DateType, but given ts_col {self.primary_ts_idx.name} has type {self.primary_ts_idx.dataType}")
+ raise TypeError(
+ f"ParsedDateIndex must be of DateType, but given ts_col {self.primary_ts_idx.name} has type {self.primary_ts_idx.dataType}"
+ )
def rangeOrderByExpr(self, reverse: bool = False) -> Column:
# convert date to number of days since the epoch
- expr = Fn.datediff(Fn.col(self.primary_ts_col), Fn.lit("1970-01-01").cast("date"))
- return self._reverseOrNot(expr,reverse)
+ expr = Fn.datediff(
+ Fn.col(self.primary_ts_col), Fn.lit("1970-01-01").cast("date")
+ )
+ return self._reverseOrNot(expr, reverse)
+
#
# Timeseries Schema
#
+
class TSSchema:
"""
Schema type for a :class:`TSDF` class.
"""
- def __init__(
- self,
- ts_idx: TSIndex,
- series_ids: Collection[str] = None
- ) -> None:
+ def __init__(self, ts_idx: TSIndex, series_ids: Collection[str] = None) -> None:
self.ts_idx = ts_idx
if series_ids:
self.series_ids = list(series_ids)
@@ -382,15 +416,14 @@ def find_observational_columns(self, df_schema: StructType) -> list[str]:
@classmethod
def is_metric_col(cls, col: StructField) -> bool:
- return (isinstance(col.dataType, NumericType)
- or
- isinstance(col.dataType, BooleanType))
+ return isinstance(col.dataType, NumericType) or isinstance(
+ col.dataType, BooleanType
+ )
def find_metric_columns(self, df_schema: StructType) -> list[str]:
return [
col.name
for col in df_schema.fields
if self.is_metric_col(col)
- and
- (col.name in self.find_observational_columns(df_schema))
+ and (col.name in self.find_observational_columns(df_schema))
]
diff --git a/python/tests/base.py b/python/tests/base.py
index c4690e24..ae2c03ef 100644
--- a/python/tests/base.py
+++ b/python/tests/base.py
@@ -77,7 +77,9 @@ def get_data_as_tsdf(self, name: str, convert_ts_col=True):
df = self.get_data_as_sdf(name, convert_ts_col)
td = self.test_data[name]
if "sequence_col" in td:
- tsdf = TSDF.fromSubsequenceCol(df, td["ts_col"], td["sequence_col"], td.get("series_ids", None))
+ tsdf = TSDF.fromSubsequenceCol(
+ df, td["ts_col"], td["sequence_col"], td.get("series_ids", None)
+ )
else:
tsdf = TSDF(df, ts_col=td["ts_col"], series_ids=td.get("series_ids", None))
return tsdf
@@ -156,10 +158,12 @@ def buildTestDF(self, schema, data, ts_cols=["event_ts"]):
# convert timestamp fields to timestamp type
for tsc in ts_cols:
# check if the column is nested in a struct or not
- if '.' in tsc:
+ if "." in tsc:
# we're changing a field nested in a struct
- (struct, field) = tsc.split('.')
- df = df.withColumn(struct, F.col(struct).withField(field, F.to_timestamp(tsc)))
+ (struct, field) = tsc.split(".")
+ df = df.withColumn(
+ struct, F.col(struct).withField(field, F.to_timestamp(tsc))
+ )
else:
# standard column
df = df.withColumn(tsc, F.to_timestamp(F.col(tsc)))
diff --git a/python/tests/interpol_tests.py b/python/tests/interpol_tests.py
index 622eea22..149603ad 100644
--- a/python/tests/interpol_tests.py
+++ b/python/tests/interpol_tests.py
@@ -1,11 +1,8 @@
import unittest
-from pyspark.sql.types import *
-
from tempo.interpol import Interpolation
from tempo.tsdf import TSDF
from tempo.utils import *
-
from tests.tsdf_tests import SparkTest
@@ -200,7 +197,7 @@ def test_zero_fill_interpolation_no_perform_checks(self):
# interpolate
actual_df: DataFrame = self.interpolate_helper.interpolate(
tsdf=simple_input_tsdf,
- partition_cols=["partition_a", "partition_b"],
+ series_ids=["partition_a", "partition_b"],
target_cols=["value_a", "value_b"],
freq="30 seconds",
ts_col="event_ts",
@@ -401,7 +398,7 @@ def test_interpolation_using_custom_params(self):
input_tsdf = TSDF(
simple_input_tsdf.df.withColumnRenamed("event_ts", "other_ts_col"),
ts_col="other_ts_col",
- series_ids=["partition_a", "partition_b"]
+ series_ids=["partition_a", "partition_b"],
)
actual_df: DataFrame = input_tsdf.interpolate(
diff --git a/python/tests/tsdf_tests.py b/python/tests/tsdf_tests.py
index 75e7bac9..4695492b 100644
--- a/python/tests/tsdf_tests.py
+++ b/python/tests/tsdf_tests.py
@@ -70,7 +70,9 @@ def __tsdf_with_double_tscol(tsdf: TSDF) -> TSDF:
with_double_tscol_df = tsdf.df.withColumn(
tsdf.ts_col, f.col(tsdf.ts_col).cast("double")
)
- return TSDF(with_double_tscol_df, ts_col=tsdf.ts_col, series_ids=tsdf.series_ids)
+ return TSDF(
+ with_double_tscol_df, ts_col=tsdf.ts_col, series_ids=tsdf.series_ids
+ )
# TODO - replace this with test code for TSDF.fromTimestampString with nano-second precision
# def test__add_double_ts(self):
From f2d266968232063a66acc998b0deb5d78eef4344 Mon Sep 17 00:00:00 2001
From: Tristan Nixon
Date: Wed, 11 Jan 2023 13:44:28 -0800
Subject: [PATCH 10/11] Standardizing pyspark.sql.functions as Fn
---
python/tempo/intervals.py | 42 +++++++--------
python/tempo/io.py | 6 +--
python/tempo/resample.py | 44 +++++++--------
python/tests/base.py | 6 +--
python/tests/intervals_tests.py | 14 ++---
python/tests/tsdf_tests.py | 94 ++++++++++++++++-----------------
6 files changed, 103 insertions(+), 103 deletions(-)
diff --git a/python/tempo/intervals.py b/python/tempo/intervals.py
index 12f7fe66..376eb543 100644
--- a/python/tempo/intervals.py
+++ b/python/tempo/intervals.py
@@ -5,7 +5,7 @@
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import NumericType, BooleanType, StructField
-import pyspark.sql.functions as f
+import pyspark.sql.functions as Fn
from pyspark.sql.window import Window
@@ -205,10 +205,10 @@ def __get_adjacent_rows(self, df: DataFrame) -> DataFrame:
for c in self.interval_boundaries + self.metric_columns:
df = df.withColumn(
f"_lead_1_{c}",
- f.lead(c, 1).over(self.window),
+ Fn.lead(c, 1).over(self.window),
).withColumn(
f"_lag_1_{c}",
- f.lag(c, 1).over(self.window),
+ Fn.lag(c, 1).over(self.window),
)
return df
@@ -231,8 +231,8 @@ def __identify_subset_intervals(self, df: DataFrame) -> tuple[DataFrame, str]:
df = df.withColumn(
subset_indicator,
- (f.col(f"_lag_1_{self.start_ts}") <= f.col(self.start_ts))
- & (f.col(f"_lag_1_{self.end_ts}") >= f.col(self.end_ts)),
+ (Fn.col(f"_lag_1_{self.start_ts}") <= Fn.col(self.start_ts))
+ & (Fn.col(f"_lag_1_{self.end_ts}") >= Fn.col(self.end_ts)),
)
# NB: the first record cannot be a subset of the previous and
@@ -266,12 +266,12 @@ def __identify_overlaps(self, df: DataFrame) -> tuple[DataFrame, list[str]]:
for ts in self.interval_boundaries:
df = df.withColumn(
f"_lead_1_{ts}_overlaps",
- (f.col(f"_lead_1_{ts}") > f.col(self.start_ts))
- & (f.col(f"_lead_1_{ts}") < f.col(self.end_ts)),
+ (Fn.col(f"_lead_1_{ts}") > Fn.col(self.start_ts))
+ & (Fn.col(f"_lead_1_{ts}") < Fn.col(self.end_ts)),
).withColumn(
f"_lag_1_{ts}_overlaps",
- (f.col(f"_lag_1_{ts}") > f.col(self.start_ts))
- & (f.col(f"_lag_1_{ts}") < f.col(self.end_ts)),
+ (Fn.col(f"_lag_1_{ts}") > Fn.col(self.start_ts))
+ & (Fn.col(f"_lag_1_{ts}") < Fn.col(self.end_ts)),
)
overlap_indicators.extend(
@@ -316,9 +316,9 @@ def __merge_adjacent_subset_and_superset(
for c in self.metric_columns:
df = df.withColumn(
c,
- f.when(
- f.col(subset_indicator), f.coalesce(f.col(c), f"_lag_1_{c}")
- ).otherwise(f.col(c)),
+ Fn.when(
+ Fn.col(subset_indicator), Fn.coalesce(Fn.col(c), f"_lag_1_{c}")
+ ).otherwise(Fn.col(c)),
)
return df
@@ -382,7 +382,7 @@ def __merge_adjacent_overlaps(
df = df.withColumn(
new_boundary_col,
- f.expr(new_interval_boundaries),
+ Fn.expr(new_interval_boundaries),
)
if how == "left":
@@ -392,13 +392,13 @@ def __merge_adjacent_overlaps(
c,
# needed when intervals have same start but different ends
# in this case, merge metrics since they overlap
- f.when(
- f.col(f"_lag_1_{self.end_ts}_overlaps"),
- f.coalesce(f.col(c), f.col(f"_lag_1_{c}")),
+ Fn.when(
+ Fn.col(f"_lag_1_{self.end_ts}_overlaps"),
+ Fn.coalesce(Fn.col(c), Fn.col(f"_lag_1_{c}")),
)
# general case when constructing left disjoint interval
# just want new boundary without merging metrics
- .otherwise(f.col(c)),
+ .otherwise(Fn.col(c)),
)
return df
@@ -421,7 +421,7 @@ def __merge_equal_intervals(self, df: DataFrame) -> DataFrame:
"""
- merge_expr = tuple(f.max(c).alias(c) for c in self.metric_columns)
+ merge_expr = tuple(Fn.max(c).alias(c) for c in self.metric_columns)
return df.groupBy(*self.interval_boundaries, *self.series_ids).agg(*merge_expr)
@@ -467,7 +467,7 @@ def disjoint(self) -> "IntervalsDF":
(df, subset_indicator) = self.__identify_subset_intervals(df)
- subset_df = df.filter(f.col(subset_indicator))
+ subset_df = df.filter(Fn.col(subset_indicator))
subset_df = self.__merge_adjacent_subset_and_superset(
subset_df, subset_indicator
@@ -477,7 +477,7 @@ def disjoint(self) -> "IntervalsDF":
*self.interval_boundaries, *self.series_ids, *self.metric_columns
)
- non_subset_df = df.filter(~f.col(subset_indicator))
+ non_subset_df = df.filter(~Fn.col(subset_indicator))
(non_subset_df, overlap_indicators) = self.__identify_overlaps(non_subset_df)
@@ -610,7 +610,7 @@ def toDF(self, stack: bool = False) -> DataFrame:
)
return self.df.select(
- *self.interval_boundaries, *self.series_ids, f.expr(stack_expr)
+ *self.interval_boundaries, *self.series_ids, Fn.expr(stack_expr)
).dropna(subset="metric_value")
else:
diff --git a/python/tempo/io.py b/python/tempo/io.py
index 1399b590..838e0b67 100644
--- a/python/tempo/io.py
+++ b/python/tempo/io.py
@@ -4,7 +4,7 @@
import logging
from collections import deque
import tempo
-import pyspark.sql.functions as f
+import pyspark.sql.functions as Fn
from pyspark.sql import SparkSession
from pyspark.sql.utils import ParseException
@@ -35,9 +35,9 @@ def write(
useDeltaOpt = os.getenv("DATABRICKS_RUNTIME_VERSION") is not None
- view_df = df.withColumn("event_dt", f.to_date(f.col(ts_col))).withColumn(
+ view_df = df.withColumn("event_dt", Fn.to_date(Fn.col(ts_col))).withColumn(
"event_time",
- f.translate(f.split(f.col(ts_col).cast("string"), " ")[1], ":", "").cast(
+ Fn.translate(Fn.split(Fn.col(ts_col).cast("string"), " ")[1], ":", "").cast(
"double"
),
)
diff --git a/python/tempo/resample.py b/python/tempo/resample.py
index baae03a2..7acf1078 100644
--- a/python/tempo/resample.py
+++ b/python/tempo/resample.py
@@ -4,7 +4,7 @@
import tempo
-import pyspark.sql.functions as f
+import pyspark.sql.functions as Fn
from pyspark.sql.window import Window
from pyspark.sql import DataFrame
@@ -45,8 +45,8 @@ def _appendAggKey(tsdf: tempo.TSDF, freq: str = None):
"""
df = tsdf.df
parsed_freq = checkAllowableFreq(freq)
- agg_window = f.window(
- f.col(tsdf.ts_col), "{} {}".format(parsed_freq[0], freq_dict[parsed_freq[1]])
+ agg_window = Fn.window(
+ Fn.col(tsdf.ts_col), "{} {}".format(parsed_freq[0], freq_dict[parsed_freq[1]])
)
df = df.withColumn("agg_key", agg_window)
@@ -88,16 +88,16 @@ def aggregate(
else:
prefix = prefix + "_"
- groupingCols = [f.col(column) for column in groupingCols]
+ groupingCols = [Fn.col(column) for column in groupingCols]
if func == floor:
- metricCol = f.struct([tsdf.ts_col] + metricCols)
+ metricCol = Fn.struct([tsdf.ts_col] + metricCols)
res = df.withColumn("struct_cols", metricCol).groupBy(groupingCols)
- res = res.agg(f.min("struct_cols").alias("closest_data")).select(
- *groupingCols, f.col("closest_data.*")
+ res = res.agg(Fn.min("struct_cols").alias("closest_data")).select(
+ *groupingCols, Fn.col("closest_data.*")
)
- new_cols = [f.col(tsdf.ts_col)] + [
- f.col(c).alias("{}".format(prefix) + c) for c in metricCols
+ new_cols = [Fn.col(tsdf.ts_col)] + [
+ Fn.col(c).alias("{}".format(prefix) + c) for c in metricCols
]
res = res.select(*groupingCols, *new_cols)
elif func == average:
@@ -107,7 +107,7 @@ def aggregate(
set(res.columns).difference(set(tsdf.series_ids + [tsdf.ts_col, "agg_key"]))
)
new_cols = [
- f.col(c).alias("{}".format(prefix) + (c.split("avg(")[1]).replace(")", ""))
+ Fn.col(c).alias("{}".format(prefix) + (c.split("avg(")[1]).replace(")", ""))
for c in agg_metric_cls
]
res = res.select(*groupingCols, *new_cols)
@@ -118,7 +118,7 @@ def aggregate(
set(res.columns).difference(set(tsdf.series_ids + [tsdf.ts_col, "agg_key"]))
)
new_cols = [
- f.col(c).alias("{}".format(prefix) + (c.split("min(")[1]).replace(")", ""))
+ Fn.col(c).alias("{}".format(prefix) + (c.split("min(")[1]).replace(")", ""))
for c in agg_metric_cls
]
res = res.select(*groupingCols, *new_cols)
@@ -129,18 +129,18 @@ def aggregate(
set(res.columns).difference(set(tsdf.series_ids + [tsdf.ts_col, "agg_key"]))
)
new_cols = [
- f.col(c).alias("{}".format(prefix) + (c.split("max(")[1]).replace(")", ""))
+ Fn.col(c).alias("{}".format(prefix) + (c.split("max(")[1]).replace(")", ""))
for c in agg_metric_cls
]
res = res.select(*groupingCols, *new_cols)
elif func == ceiling:
- metricCol = f.struct([tsdf.ts_col] + metricCols)
+ metricCol = Fn.struct([tsdf.ts_col] + metricCols)
res = df.withColumn("struct_cols", metricCol).groupBy(groupingCols)
- res = res.agg(f.max("struct_cols").alias("ceil_data")).select(
- *groupingCols, f.col("ceil_data.*")
+ res = res.agg(Fn.max("struct_cols").alias("ceil_data")).select(
+ *groupingCols, Fn.col("ceil_data.*")
)
- new_cols = [f.col(tsdf.ts_col)] + [
- f.col(c).alias("{}".format(prefix) + c) for c in metricCols
+ new_cols = [Fn.col(tsdf.ts_col)] + [
+ Fn.col(c).alias("{}".format(prefix) + c) for c in metricCols
]
res = res.select(*groupingCols, *new_cols)
@@ -148,7 +148,7 @@ def aggregate(
res = (
res.drop(tsdf.ts_col)
.withColumnRenamed("agg_key", tsdf.ts_col)
- .withColumn(tsdf.ts_col, f.col(tsdf.ts_col).start)
+ .withColumn(tsdf.ts_col, Fn.col(tsdf.ts_col).start)
)
# sort columns so they are consistent
@@ -161,14 +161,14 @@ def aggregate(
imputes = (
res.select(
*tsdf.series_ids,
- f.min(tsdf.ts_col).over(fillW).alias("from"),
- f.max(tsdf.ts_col).over(fillW).alias("until"),
+ Fn.min(tsdf.ts_col).over(fillW).alias("from"),
+ Fn.max(tsdf.ts_col).over(fillW).alias("until"),
)
.distinct()
.withColumn(
tsdf.ts_col,
- f.explode(
- f.expr("sequence(from, until, interval {} {})".format(period, unit))
+ Fn.explode(
+ Fn.expr("sequence(from, until, interval {} {})".format(period, unit))
),
)
.drop("from", "until")
diff --git a/python/tests/base.py b/python/tests/base.py
index ae2c03ef..e8bc52f1 100644
--- a/python/tests/base.py
+++ b/python/tests/base.py
@@ -4,7 +4,7 @@
from typing import Union
import jsonref
-import pyspark.sql.functions as F
+import pyspark.sql.functions as Fn
from chispa import assert_df_equality
from pyspark.sql import SparkSession
from pyspark.sql.dataframe import DataFrame
@@ -162,11 +162,11 @@ def buildTestDF(self, schema, data, ts_cols=["event_ts"]):
# we're changing a field nested in a struct
(struct, field) = tsc.split(".")
df = df.withColumn(
- struct, F.col(struct).withField(field, F.to_timestamp(tsc))
+ struct, Fn.col(struct).withField(field, Fn.to_timestamp(tsc))
)
else:
# standard column
- df = df.withColumn(tsc, F.to_timestamp(F.col(tsc)))
+ df = df.withColumn(tsc, Fn.to_timestamp(Fn.col(tsc)))
return df
#
diff --git a/python/tests/intervals_tests.py b/python/tests/intervals_tests.py
index 841b4dd2..379df6ca 100644
--- a/python/tests/intervals_tests.py
+++ b/python/tests/intervals_tests.py
@@ -1,7 +1,7 @@
from tempo.intervals import *
from tests.tsdf_tests import SparkTest
from pyspark.sql.utils import AnalysisException
-import pyspark.sql.functions as f
+import pyspark.sql.functions as Fn
class IntervalsDFTests(SparkTest):
@@ -141,8 +141,8 @@ def test_fromStackedMetrics_series_list(self):
idf_expected = self.get_data_as_idf("expected")
df_input = df_input.withColumn(
- "start_ts", f.to_timestamp("start_ts")
- ).withColumn("end_ts", f.to_timestamp("end_ts"))
+ "start_ts", Fn.to_timestamp("start_ts")
+ ).withColumn("end_ts", Fn.to_timestamp("end_ts"))
idf = IntervalsDF.fromStackedMetrics(
df_input,
@@ -162,8 +162,8 @@ def test_fromStackedMetrics_metric_names(self):
idf_expected = self.get_data_as_idf("expected")
df_input = df_input.withColumn(
- "start_ts", f.to_timestamp("start_ts")
- ).withColumn("end_ts", f.to_timestamp("end_ts"))
+ "start_ts", Fn.to_timestamp("start_ts")
+ ).withColumn("end_ts", Fn.to_timestamp("end_ts"))
idf = IntervalsDF.fromStackedMetrics(
df_input,
@@ -336,8 +336,8 @@ def test_toDF_stack(self):
expected_df = self.get_data_as_sdf("expected")
expected_df = expected_df.withColumn(
- "start_ts", f.to_timestamp("start_ts")
- ).withColumn("end_ts", f.to_timestamp("end_ts"))
+ "start_ts", Fn.to_timestamp("start_ts")
+ ).withColumn("end_ts", Fn.to_timestamp("end_ts"))
actual_df = idf_input.toDF(stack=True)
diff --git a/python/tests/tsdf_tests.py b/python/tests/tsdf_tests.py
index 4695492b..56098f7b 100644
--- a/python/tests/tsdf_tests.py
+++ b/python/tests/tsdf_tests.py
@@ -7,7 +7,7 @@
from pyspark.sql.column import Column
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.window import WindowSpec
-import pyspark.sql.functions as f
+import pyspark.sql.functions as Fn
from tempo.tsdf import TSDF
from tests.base import SparkTest
@@ -34,20 +34,20 @@ def test_describe(self):
# self.assertDataFrameEquality(res, dfExpected)
assert res.count() == 7
assert (
- res.filter(f.col("unique_time_series_count") != " ")
- .select(f.max(f.col("unique_time_series_count")))
+ res.filter(Fn.col("unique_time_series_count") != " ")
+ .select(Fn.max(Fn.col("unique_time_series_count")))
.collect()[0][0]
== "1"
)
assert (
- res.filter(f.col("min_ts") != " ")
- .select(f.col("min_ts").cast("string"))
+ res.filter(Fn.col("min_ts") != " ")
+ .select(Fn.col("min_ts").cast("string"))
.collect()[0][0]
== "2020-08-01 00:00:10"
)
assert (
- res.filter(f.col("max_ts") != " ")
- .select(f.col("max_ts").cast("string"))
+ res.filter(Fn.col("max_ts") != " ")
+ .select(Fn.col("max_ts").cast("string"))
.collect()[0][0]
== "2020-09-01 00:19:12"
)
@@ -68,7 +68,7 @@ def __timestamp_to_double(ts: str) -> float:
@staticmethod
def __tsdf_with_double_tscol(tsdf: TSDF) -> TSDF:
with_double_tscol_df = tsdf.df.withColumn(
- tsdf.ts_col, f.col(tsdf.ts_col).cast("double")
+ tsdf.ts_col, Fn.col(tsdf.ts_col).cast("double")
)
return TSDF(
with_double_tscol_df, ts_col=tsdf.ts_col, series_ids=tsdf.series_ids
@@ -856,28 +856,28 @@ def test_range_stats(self):
# cast to decimal with precision in cents for simplicity
featured_df = featured_df.select(
- f.col("symbol"),
- f.col("event_ts"),
- f.col("mean_trade_pr").cast("decimal(5, 2)"),
- f.col("count_trade_pr"),
- f.col("min_trade_pr").cast("decimal(5,2)"),
- f.col("max_trade_pr").cast("decimal(5,2)"),
- f.col("sum_trade_pr").cast("decimal(5,2)"),
- f.col("stddev_trade_pr").cast("decimal(5,2)"),
- f.col("zscore_trade_pr").cast("decimal(5,2)"),
+ Fn.col("symbol"),
+ Fn.col("event_ts"),
+ Fn.col("mean_trade_pr").cast("decimal(5, 2)"),
+ Fn.col("count_trade_pr"),
+ Fn.col("min_trade_pr").cast("decimal(5,2)"),
+ Fn.col("max_trade_pr").cast("decimal(5,2)"),
+ Fn.col("sum_trade_pr").cast("decimal(5,2)"),
+ Fn.col("stddev_trade_pr").cast("decimal(5,2)"),
+ Fn.col("zscore_trade_pr").cast("decimal(5,2)"),
)
# cast to decimal with precision in cents for simplicity
dfExpected = dfExpected.select(
- f.col("symbol"),
- f.col("event_ts"),
- f.col("mean_trade_pr").cast("decimal(5, 2)"),
- f.col("count_trade_pr"),
- f.col("min_trade_pr").cast("decimal(5,2)"),
- f.col("max_trade_pr").cast("decimal(5,2)"),
- f.col("sum_trade_pr").cast("decimal(5,2)"),
- f.col("stddev_trade_pr").cast("decimal(5,2)"),
- f.col("zscore_trade_pr").cast("decimal(5,2)"),
+ Fn.col("symbol"),
+ Fn.col("event_ts"),
+ Fn.col("mean_trade_pr").cast("decimal(5, 2)"),
+ Fn.col("count_trade_pr"),
+ Fn.col("min_trade_pr").cast("decimal(5,2)"),
+ Fn.col("max_trade_pr").cast("decimal(5,2)"),
+ Fn.col("sum_trade_pr").cast("decimal(5,2)"),
+ Fn.col("stddev_trade_pr").cast("decimal(5,2)"),
+ Fn.col("zscore_trade_pr").cast("decimal(5,2)"),
)
# should be equal to the expected dataframe
@@ -895,26 +895,26 @@ def test_group_stats(self):
# cast to decimal with precision in cents for simplicity
featured_df = featured_df.select(
- f.col("symbol"),
- f.col("event_ts"),
- f.col("mean_trade_pr").cast("decimal(5, 2)"),
- f.col("count_trade_pr"),
- f.col("min_trade_pr").cast("decimal(5,2)"),
- f.col("max_trade_pr").cast("decimal(5,2)"),
- f.col("sum_trade_pr").cast("decimal(5,2)"),
- f.col("stddev_trade_pr").cast("decimal(5,2)"),
+ Fn.col("symbol"),
+ Fn.col("event_ts"),
+ Fn.col("mean_trade_pr").cast("decimal(5, 2)"),
+ Fn.col("count_trade_pr"),
+ Fn.col("min_trade_pr").cast("decimal(5,2)"),
+ Fn.col("max_trade_pr").cast("decimal(5,2)"),
+ Fn.col("sum_trade_pr").cast("decimal(5,2)"),
+ Fn.col("stddev_trade_pr").cast("decimal(5,2)"),
)
# cast to decimal with precision in cents for simplicity
dfExpected = dfExpected.select(
- f.col("symbol"),
- f.col("event_ts"),
- f.col("mean_trade_pr").cast("decimal(5, 2)"),
- f.col("count_trade_pr"),
- f.col("min_trade_pr").cast("decimal(5,2)"),
- f.col("max_trade_pr").cast("decimal(5,2)"),
- f.col("sum_trade_pr").cast("decimal(5,2)"),
- f.col("stddev_trade_pr").cast("decimal(5,2)"),
+ Fn.col("symbol"),
+ Fn.col("event_ts"),
+ Fn.col("mean_trade_pr").cast("decimal(5, 2)"),
+ Fn.col("count_trade_pr"),
+ Fn.col("min_trade_pr").cast("decimal(5,2)"),
+ Fn.col("max_trade_pr").cast("decimal(5,2)"),
+ Fn.col("sum_trade_pr").cast("decimal(5,2)"),
+ Fn.col("stddev_trade_pr").cast("decimal(5,2)"),
)
# should be equal to the expected dataframe
@@ -935,7 +935,7 @@ def test_resample(self):
featured_df = tsdf_input.resample(freq="min", func="floor", prefix="floor").df
# 30 minute aggregation
resample_30m = tsdf_input.resample(freq="5 minutes", func="mean").df.withColumn(
- "trade_pr", f.round(f.col("trade_pr"), 2)
+ "trade_pr", Fn.round(Fn.col("trade_pr"), 2)
)
bars = tsdf_input.calc_bars(
@@ -958,7 +958,7 @@ def test_resample_millis(self):
# 30 minute aggregation
resample_ms = tsdf_init.resample(freq="ms", func="mean").df.withColumn(
- "trade_pr", f.round(f.col("trade_pr"), 2)
+ "trade_pr", Fn.round(Fn.col("trade_pr"), 2)
)
self.assertDataFrameEquality(resample_ms, dfExpected)
@@ -973,14 +973,14 @@ def test_upsample(self):
resample_30m = tsdf_input.resample(
freq="5 minutes", func="mean", fill=True
- ).df.withColumn("trade_pr", f.round(f.col("trade_pr"), 2))
+ ).df.withColumn("trade_pr", Fn.round(Fn.col("trade_pr"), 2))
bars = tsdf_input.calc_bars(
freq="min", metricCols=["trade_pr", "trade_pr_2"]
).df
upsampled = resample_30m.filter(
- f.col("event_ts").isin(
+ Fn.col("event_ts").isin(
"2020-08-01 00:00:00",
"2020-08-01 00:05:00",
"2020-09-01 00:00:00",
@@ -1173,7 +1173,7 @@ def test_threshold_fn(self):
# threshold state function
def threshold_fn(a: Column, b: Column) -> Column:
- return f.abs(a - b) < f.lit(0.5)
+ return Fn.abs(a - b) < Fn.lit(0.5)
# call extractStateIntervals method
extracted_intervals_df: DataFrame = input_tsdf.extractStateIntervals(
From b8e8f8e51f958c062fed10024459ad4028be8a11 Mon Sep 17 00:00:00 2001
From: Tristan Nixon
Date: Fri, 24 Feb 2023 15:50:35 -0800
Subject: [PATCH 11/11] committing WIP - migrating to new laptop
---
examples/dlt_tempo.py | 2 +-
examples/financial_services_quickstart.py | 6 +-
python/tempo/tsdf.py | 64 ++---
python/tempo/tsschema.py | 278 ++++++++++++++--------
python/tests/base.py | 8 +-
python/tests/interpol_tests.py | 7 +-
python/tests/tsdf_tests.py | 4 +-
7 files changed, 215 insertions(+), 154 deletions(-)
diff --git a/examples/dlt_tempo.py b/examples/dlt_tempo.py
index dba0824b..f532a9f8 100644
--- a/examples/dlt_tempo.py
+++ b/examples/dlt_tempo.py
@@ -26,7 +26,7 @@ def ts_bronze():
@dlt.expect_or_drop("User_check","User in ('a','c','i')")
def ts_ft():
phone_accel_df = dlt.read("ts_bronze")
- phone_accel_tsdf = TSDF(phone_accel_df, ts_col="event_ts", partition_cols = ["User"])
+ phone_accel_tsdf = TSDF(phone_accel_df, ts_col="event_ts")
ts_ft_df = phone_accel_tsdf.fourier_transform(timestep=1, valueCol="x").df
return ts_ft_df
diff --git a/examples/financial_services_quickstart.py b/examples/financial_services_quickstart.py
index d515bd97..ee638170 100644
--- a/examples/financial_services_quickstart.py
+++ b/examples/financial_services_quickstart.py
@@ -92,8 +92,8 @@
# DBTITLE 1,Define TSDF Time Series Data Structure
from tempo import *
-trades_tsdf = TSDF(trades_df, partition_cols = ['date', 'symbol'], ts_col = 'event_ts')
-quotes_tsdf = TSDF(quotes_df, partition_cols = ['date', 'symbol'], ts_col = 'event_ts')
+trades_tsdf = TSDF(trades_df, ts_col='event_ts')
+quotes_tsdf = TSDF(quotes_df, ts_col='event_ts')
# COMMAND ----------
@@ -178,7 +178,7 @@
from tempo import *
from pyspark.sql.functions import *
-minute_bars = TSDF(spark.table("time_test"), partition_cols=['ticker'], ts_col="ts").calc_bars(freq = '1 minute', func= 'ceil')
+minute_bars = TSDF(spark.table("time_test"), ts_col="ts").calc_bars(freq='1 minute', func='ceil')
display(minute_bars)
diff --git a/python/tempo/tsdf.py b/python/tempo/tsdf.py
index 8bb118f6..8b6b51e2 100644
--- a/python/tempo/tsdf.py
+++ b/python/tempo/tsdf.py
@@ -4,7 +4,7 @@
import operator
from copy import deepcopy
from functools import reduce, cached_property
-from typing import List, Union, Callable, Collection
+from typing import cast, List, Union, Callable, Collection
import numpy as np
import pyspark.sql.functions as Fn
@@ -19,7 +19,7 @@
import tempo.io as tio
import tempo.resample as rs
from tempo.interpol import Interpolation
-from tempo.tsschema import TSIndex, TSSchema, SubSequenceTSIndex
+from tempo.tsschema import TSIndex, TSSchema, CompositeTSIndex
from tempo.utils import (
ENV_CAN_RENDER_HTML,
IS_DATABRICKS,
@@ -62,14 +62,11 @@ class TSDF:
This object is the main wrapper over a Spark data frame which allows a user to parallelize time series computations on a Spark data frame by various dimensions. The two dimensions required are partition_cols (list of columns by which to summarize) and ts_col (timestamp column, which can be epoch or TimestampType).
"""
- def __init__(
- self,
- df: DataFrame,
- ts_schema: TSSchema = None,
- ts_col: str = None,
- series_ids: Collection[str] = None,
- validate_schema=True,
- ) -> None:
+ def __init__(self,
+ df: DataFrame,
+ ts_schema: TSSchema = None,
+ ts_col: str = None,
+ series_ids: Collection[str] = None) -> None:
self.df = df
# construct schema if we don't already have one
if ts_schema:
@@ -77,8 +74,7 @@ def __init__(
else:
self.ts_schema = TSSchema.fromDFSchema(self.df.schema, ts_col, series_ids)
# validate that this schema works for this DataFrame
- if validate_schema:
- self.ts_schema.validate(df.schema)
+ self.ts_schema.validate(df.schema)
def __repr__(self) -> str:
return self.__str__()
@@ -98,7 +94,7 @@ def __withTransformedDF(self, new_df: DataFrame) -> "TSDF":
:return: a new TSDF object with the transformed DataFrame
"""
- return TSDF(new_df, ts_schema=deepcopy(self.ts_schema), validate_schema=False)
+ return TSDF(new_df, ts_schema=deepcopy(self.ts_schema))
def __withStandardizedColOrder(self) -> TSDF:
"""
@@ -110,7 +106,7 @@ def __withStandardizedColOrder(self) -> TSDF:
:return: a :class:`TSDF` with the columns reordered into "standard order" (as described above)
"""
std_ordered_cols = (
- list(self.series_ids) + [self.ts_index.name] + list(self.observational_cols)
+ list(self.series_ids) + [self.ts_index.colname] + list(self.observational_cols)
)
return self.__withTransformedDF(self.df.select(std_ordered_cols))
@@ -150,7 +146,7 @@ def fromSubsequenceCol(
)
# construct an appropriate TSIndex
subseq_struct = with_subseq_struct_df.schema[struct_col_name]
- subseq_idx = SubSequenceTSIndex(subseq_struct, ts_col, subsequence_col)
+ subseq_idx = CompositeTSIndex(subseq_struct, ts_col, subsequence_col)
# construct & return the TSDF with appropriate schema
return TSDF(with_subseq_struct_df, ts_schema=TSSchema(subseq_idx, series_ids))
@@ -453,15 +449,13 @@ def __getTimePartitions(self, tsPartitionVal, fraction=0.1):
df = partition_df.union(remainder_df).drop(
"partition_remainder", "ts_col_double"
)
- return TSDF(
- df, ts_col=self.ts_col, series_ids=self.series_ids + ["ts_partition"]
- )
+ return TSDF(df, ts_col=self.ts_col, series_ids=self.series_ids + ["ts_partition"])
#
# Slicing & Selection
#
- def select(self, *cols):
+ def select(self, *cols: Union[str, Column]) -> TSDF:
"""
pyspark.sql.DataFrame.select() method's equivalent for TSDF objects
Parameters
@@ -480,12 +474,8 @@ def select(self, *cols):
"""
# The columns which will be a mandatory requirement while selecting from TSDFs
- if set(self.structural_cols).issubset(set(cols)):
- return self.__withTransformedDF(self.df.select(*cols))
- else:
- raise TSDFStructureChangeError(
- "select that does not include all structural columns"
- )
+ selected_df = self.df.select(*cols)
+ return self.__withTransformedDF(selected_df)
def __slice(self, op: str, target_ts):
"""
@@ -801,7 +791,7 @@ def asofJoin(
Performs an as-of join between two time-series. If a tsPartitionVal is specified, it will do this partitioned by
time brackets, which can help alleviate skew.
- NOTE: partition cols have to be the same for both Dataframes. We are collecting stats when the WARNING level is
+ NOTE: Series IDs have to be the same for both Dataframes. We are collecting stats when the WARNING level is
enabled also.
Parameters
@@ -875,7 +865,7 @@ def asofJoin(
)
.drop("lead_" + right_tsdf.ts_col)
)
- return TSDF(res, series_ids=self.series_ids, ts_col=new_left_ts_col)
+ return TSDF(res, ts_col=new_left_ts_col, series_ids=self.series_ids)
# end of block checking to see if standard Spark SQL join will work
@@ -931,8 +921,8 @@ def asofJoin(
# perform asof join.
if tsPartitionVal is None:
seq_col = None
- if isinstance(combined_df.ts_index, SubSequenceTSIndex):
- seq_col = combined_df.ts_index.sub_seq_col
+ if isinstance(combined_df.ts_index, CompositeTSIndex):
+ seq_col = cast(CompositeTSIndex, combined_df.ts_index).ts_component(1)
asofDF = combined_df.__getLastRightRow(
left_tsdf.ts_col,
right_columns,
@@ -946,8 +936,8 @@ def asofJoin(
tsPartitionVal, fraction=fraction
)
seq_col = None
- if isinstance(tsPartitionDF.ts_index, SubSequenceTSIndex):
- seq_col = tsPartitionDF.ts_index.sub_seq_col
+ if isinstance(tsPartitionDF.ts_index, CompositeTSIndex):
+ seq_col = cast(CompositeTSIndex, tsPartitionDF.ts_index).ts_component(1)
asofDF = tsPartitionDF.__getLastRightRow(
left_tsdf.ts_col,
right_columns,
@@ -981,7 +971,7 @@ def __rowsBetweenWindow(self, rows_from, rows_to, reverse=False):
def __rangeBetweenWindow(self, range_from, range_to, reverse=False):
return (
self.__baseWindow(reverse=reverse)
- .orderBy(self.ts_index.rangeOrderByExpr(reverse=reverse))
+ .orderBy(self.ts_index.rangeExpr(reverse=reverse))
.rangeBetween(range_from, range_to)
)
@@ -1006,10 +996,6 @@ def withColumn(self, colName: str, col: Column) -> "TSDF":
:param colName: the name of the new column (or existing column to be replaced)
:param col: a :class:`Column` expression for the new column definition
"""
- if colName in self.structural_cols:
- raise TSDFStructureChangeError(
- f"withColumn on the structural column {colName}."
- )
new_df = self.df.withColumn(colName, col)
return self.__withTransformedDF(new_df)
@@ -1023,7 +1009,7 @@ def withColumnRenamed(self, existing: str, new: str) -> "TSDF":
# create new TSIndex
new_ts_index = deepcopy(self.ts_index)
- if existing == self.ts_index.name:
+ if existing == self.ts_index.colname:
new_ts_index = new_ts_index.renamed(new)
# and for series ids
@@ -1390,9 +1376,7 @@ def calc_bars(tsdf, freq, func=None, metricCols=None, fill=None):
)
bars = bars.select(sel_and_sort)
- return TSDF(
- bars, ts_col=resample_open.ts_col, series_ids=resample_open.series_ids
- )
+ return TSDF(bars, ts_col=resample_open.ts_col, series_ids=resample_open.series_ids)
def fourier_transform(self, timestep, valueCol):
"""
diff --git a/python/tempo/tsschema.py b/python/tempo/tsschema.py
index 51b019da..45aa4b6b 100644
--- a/python/tempo/tsschema.py
+++ b/python/tempo/tsschema.py
@@ -1,11 +1,26 @@
+from enum import Enum, auto
from abc import ABC, abstractmethod
-from typing import Any, Union, Collection, List
+from typing import cast, Any, Union, Optional, Collection, List
import pyspark.sql.functions as Fn
from pyspark.sql import Column
from pyspark.sql.types import *
from pyspark.sql.types import NumericType
+#
+# Time Units
+#
+
+class TimeUnits(Enum):
+ YEARS = auto()
+ MONTHS = auto()
+ DAYS = auto()
+ HOURS = auto()
+ MINUTES = auto()
+ SECONDS = auto()
+ MICROSECONDS = auto()
+ NANOSECONDS = auto()
+
#
# Timeseries Index Classes
@@ -21,24 +36,24 @@ def __eq__(self, o: object) -> bool:
# must be a SimpleTSIndex
if not isinstance(o, TSIndex):
return False
- return self.indexAttributes == o.indexAttributes
+ return self._indexAttributes == o._indexAttributes
def __repr__(self) -> str:
return self.__str__()
def __str__(self) -> str:
- return f"""{self.__class__.__name__}({self.indexAttributes})"""
+ return f"""{self.__class__.__name__}({self._indexAttributes})"""
@property
@abstractmethod
- def indexAttributes(self) -> dict[str, Any]:
+ def _indexAttributes(self) -> dict[str, Any]:
"""
:return: key attributes of this index
"""
@property
@abstractmethod
- def name(self) -> str:
+ def colname(self) -> str:
"""
:return: the column name of the timeseries index
"""
@@ -50,6 +65,20 @@ def ts_col(self) -> str:
:return: the name of the primary timeseries column (may or may not be the same as the name)
"""
+ @property
+ @abstractmethod
+ def unit(self) -> Optional[TimeUnits]:
+ """
+ :return: the unit of this index, that is, the unit that a range value of 1 represents (Days, seconds, etc.)
+ """
+
+ @abstractmethod
+ def validate(self, df_schema: StructType) -> None:
+ """
+ Validate that this TSIndex is correctly represented in the given schema
+ :param df_schema: the schema for a :class:`DataFrame`
+ """
+
@abstractmethod
def renamed(self, new_name: str) -> "TSIndex":
"""
@@ -69,6 +98,8 @@ def _reverseOrNot(
return expr.desc() # reverse a single-expression
elif type(expr) == List[Column]:
return [col.desc() for col in expr] # reverse all columns in the expression
+ else:
+ raise TypeError(f"Type for expr argument must be either Column or List[Column], instead received: {type(expr)}")
@abstractmethod
def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]:
@@ -80,16 +111,15 @@ def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]:
:return: an expression appropriate for ordering the :class:`TSDF` according to this index
"""
- def rangeOrderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]:
+ @abstractmethod
+ def rangeExpr(self, reverse: bool = False) -> Column:
"""
Gets an expression appropriate for performing range operations on the :class:`TSDF` records.
- Defaults to the same expression giving by :method:`TSIndex.orderByExpr`
:param reverse: whether the ordering should be reversed (backwards in time)
:return: an expression appropriate for performing range operations on the :class:`TSDF` records
"""
- return self.orderByExpr(reverse=reverse)
#
@@ -103,28 +133,38 @@ class SimpleTSIndex(TSIndex, ABC):
that only reference a single column for maintaining the temporal structure
"""
- def __init__(self, ts_col: StructField) -> None:
- self.__name = ts_col.name
- self.dataType = ts_col.dataType
+ def __init__(self, ts_idx: StructField) -> None:
+ self.__name = ts_idx.name
+ self.dataType = ts_idx.dataType
@property
- def indexAttributes(self) -> dict[str, Any]:
- return {"name": self.name, "dataType": self.dataType}
+ def _indexAttributes(self) -> dict[str, Any]:
+ return {"name": self.colname, "dataType": self.dataType}
@property
- def name(self):
+ def colname(self):
return self.__name
@property
def ts_col(self) -> str:
- return self.name
+ return self.colname
+
+ def validate(self, df_schema: StructType) -> None:
+ # the ts column must exist
+ assert self.colname in df_schema.fieldNames(), \
+     f"The TSIndex column {self.colname} does not exist in the given DataFrame"
+ schema_ts_col = df_schema[self.colname]
+ # it must have the right type
+ schema_ts_type = schema_ts_col.dataType
+ assert isinstance(schema_ts_type, type(self.dataType)), \
+     f"The TSIndex column is of type {schema_ts_type}, but the expected type is {self.dataType}"
def renamed(self, new_name: str) -> "TSIndex":
self.__name = new_name
return self
def orderByExpr(self, reverse: bool = False) -> Column:
- expr = Fn.col(self.name)
+ expr = Fn.col(self.colname)
return self._reverseOrNot(expr, reverse)
@classmethod
@@ -147,12 +187,19 @@ class NumericIndex(SimpleTSIndex):
Timeseries index based on a single column of a numeric or temporal type.
"""
- def __init__(self, ts_col: StructField) -> None:
- if not isinstance(ts_col.dataType, NumericType):
+ def __init__(self, ts_idx: StructField) -> None:
+ if not isinstance(ts_idx.dataType, NumericType):
raise TypeError(
- f"NumericIndex must be of a numeric type, but ts_col {ts_col.name} has type {ts_col.dataType}"
+ f"NumericIndex must be of a numeric type, but ts_col {ts_idx.name} has type {ts_idx.dataType}"
)
- super().__init__(ts_col)
+ super().__init__(ts_idx)
+
+ @property
+ def unit(self) -> Optional[TimeUnits]:
+ return None
+
+ def rangeExpr(self, reverse: bool = False) -> Column:
+ return self.orderByExpr(reverse)
class SimpleTimestampIndex(SimpleTSIndex):
@@ -160,16 +207,20 @@ class SimpleTimestampIndex(SimpleTSIndex):
Timeseries index based on a single Timestamp column
"""
- def __init__(self, ts_col: StructField) -> None:
- if not isinstance(ts_col.dataType, TimestampType):
+ def __init__(self, ts_idx: StructField) -> None:
+ if not isinstance(ts_idx.dataType, TimestampType):
raise TypeError(
- f"SimpleTimestampIndex must be of TimestampType, but given ts_col {ts_col.name} has type {ts_col.dataType}"
+ f"SimpleTimestampIndex must be of TimestampType, but given ts_col {ts_idx.name} has type {ts_idx.dataType}"
)
- super().__init__(ts_col)
+ super().__init__(ts_idx)
- def rangeOrderByExpr(self, reverse: bool = False) -> Column:
+ @property
+ def unit(self) -> Optional[TimeUnits]:
+ return TimeUnits.SECONDS
+
+ def rangeExpr(self, reverse: bool = False) -> Column:
# cast timestamp to double (fractional seconds since epoch)
- expr = Fn.col(self.name).cast("double")
+ expr = Fn.col(self.colname).cast("double")
return self._reverseOrNot(expr, reverse)
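
To make the split between ordering and ranging concrete, a sketch of the expressions a SimpleTimestampIndex yields (the column name is illustrative):

    from pyspark.sql.types import StructField, TimestampType

    idx = SimpleTimestampIndex(StructField("event_ts", TimestampType()))
    idx.orderByExpr()  # Fn.col("event_ts") - the raw column, preserves timestamp ordering
    idx.rangeExpr()    # Fn.col("event_ts").cast("double") - range offsets become seconds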
@@ -178,52 +229,58 @@ class SimpleDateIndex(SimpleTSIndex):
Timeseries index based on a single Date column
"""
- def __init__(self, ts_col: StructField) -> None:
- if not isinstance(ts_col.dataType, DateType):
+ def __init__(self, ts_idx: StructField) -> None:
+ if not isinstance(ts_idx.dataType, DateType):
raise TypeError(
- f"DateIndex must be of DateType, but given ts_col {ts_col.name} has type {ts_col.dataType}"
+ f"DateIndex must be of DateType, but given ts_col {ts_idx.name} has type {ts_idx.dataType}"
)
- super().__init__(ts_col)
+ super().__init__(ts_idx)
+
+ @property
+ def unit(self) -> Optional[TimeUnits]:
+ return TimeUnits.DAYS
- def rangeOrderByExpr(self, reverse: bool = False) -> Column:
+ def rangeExpr(self, reverse: bool = False) -> Column:
# convert date to number of days since the epoch
- expr = Fn.datediff(Fn.col(self.name), Fn.lit("1970-01-01").cast("date"))
+ expr = Fn.datediff(Fn.col(self.colname), Fn.lit("1970-01-01").cast("date"))
return self._reverseOrNot(expr, reverse)
#
-# Compound TS Index Types
+# Complex (Multi-Field) TS Index Types
#
-class CompositeTSIndex(TSIndex, ABC):
+class CompositeTSIndex(TSIndex):
"""
Abstract base class for complex Timeseries Index classes
that involve two or more columns organized into a StructType column
"""
- def __init__(self, composite_ts_idx: StructField, primary_ts_col: str) -> None:
- if not isinstance(composite_ts_idx.dataType, StructType):
+ def __init__(self, ts_idx: StructField, *ts_fields: str) -> None:
+ if not isinstance(ts_idx.dataType, StructType):
raise TypeError(
- f"CompoundTSIndex must be of type StructType, but given compound_ts_idx {composite_ts_idx.name} has type {composite_ts_idx.dataType}"
+ f"CompoundTSIndex must be of type StructType, but given compound_ts_idx {ts_idx.name} has type {ts_idx.dataType}"
)
- self.__name: str = composite_ts_idx.name
- self.struct: StructType = composite_ts_idx.dataType
- # construct a simple TS index object for the primary column
- self.primary_ts_idx: SimpleTSIndex = SimpleTSIndex.fromTSCol(
- self.struct[primary_ts_col]
- )
+ self.__name: str = ts_idx.name
+ self.struct: StructType = ts_idx.dataType
+ # handle the timestamp fields
+ if len(ts_fields) < 1:
+ raise ValueError("A CompositeTSIndex must have at least one ts_field specified!")
+ self.ts_components = [SimpleTSIndex.fromTSCol(self.struct[field]) for field in ts_fields]
+ self.primary_ts_idx = self.ts_components[0]
+
@property
- def indexAttributes(self) -> dict[str, Any]:
+ def _indexAttributes(self) -> dict[str, Any]:
return {
- "name": self.name,
+ "name": self.colname,
"struct": self.struct,
- "primary_ts_col": self.primary_ts_idx,
+ "ts_components": self.ts_components
}
@property
- def name(self) -> str:
+ def colname(self) -> str:
return self.__name
@property
@@ -232,56 +289,57 @@ def ts_col(self) -> str:
@property
def primary_ts_col(self) -> str:
- return self.component(self.primary_ts_idx.name)
+ return self.ts_component(0)
+
+ @property
+ def unit(self) -> Optional[TimeUnits]:
+ return self.primary_ts_idx.unit
+
+ def validate(self, df_schema: StructType) -> None:
+ # validate that the composite field exists
+ assert self.colname in df_schema.fieldNames(), \
+     f"The TSIndex column {self.colname} does not exist in the given DataFrame"
+ schema_ts_col = df_schema[self.colname]
+ # it must have the right type
+ schema_ts_type = schema_ts_col.dataType
+ assert isinstance(schema_ts_type, StructType), \
+     f"The TSIndex column is of type {schema_ts_type}, but the expected type is {StructType}"
+ # validate all the TS components
+ for comp in self.ts_components:
+ comp.validate(schema_ts_type)
def renamed(self, new_name: str) -> "TSIndex":
self.__name = new_name
return self
- def component(self, component_name):
+ def component(self, component_name: str) -> str:
"""
- Returns the full path to a component column that is within the composite index
+ Returns the full path to a component field that is within the composite index
:param component_name: the name of the component element within the composite index
- :return: a column name that can be used to reference the component column from the :class:`TSDF`
+ :return: a column name that can be used to reference the component field in PySpark expressions
"""
- return f"{self.name}.{self.struct[component_name].name}"
-
- def orderByExpr(self, reverse: bool = False) -> Column:
- # default to using the primary column
- expr = Fn.col(self.primary_ts_col)
- return self._reverseOrNot(expr, reverse)
+ return f"{self.colname}.{self.struct[component_name].name}"
+ def ts_component(self, component_index: int) -> str:
+ """
+ Returns the full path to a component field that is a functional part of the timeseries.
-class SubSequenceTSIndex(CompositeTSIndex):
- """
- Timeseries Index when we have a primary timeseries column and a secondary sequencing
- column that indicates the
- """
-
- def __init__(
- self, composite_ts_idx: StructField, primary_ts_col: str, sub_seq_col: str
- ) -> None:
- super().__init__(composite_ts_idx, primary_ts_col)
- # construct a simple index for the sub-sequence column
- self.sub_sequence_idx = NumericIndex(self.struct[sub_seq_col])
-
- @property
- def indexAttributes(self) -> dict[str, Any]:
- attrs = super().indexAttributes
- attrs["sub_sequence_idx"] = self.sub_sequence_idx
- return attrs
+ :param component_index: the index giving the ordering of the component field within the timeseries
- @property
- def sub_seq_col(self) -> str:
- return self.component(self.sub_sequence_idx.name)
+ :return: a column name that can be used to reference the component field in PySpark expressions
+ """
+ return self.component(self.ts_components[component_index].colname)
- def orderByExpr(self, reverse: bool = False) -> List[Column]:
- # build a composite expression of the primary index followed by the sub-sequence index
- exprs = [Fn.col(self.primary_ts_col), Fn.col(self.sub_seq_col)]
+ def orderByExpr(self, reverse: bool = False) -> Column:
+ # build an expression for each TS component, in order
+ exprs = [Fn.col(self.component(comp.colname)) for comp in self.ts_components]
return self._reverseOrNot(exprs, reverse)
+ def rangeExpr(self, reverse: bool = False) -> Column:
+ return self.primary_ts_idx.rangeExpr(reverse)
+
class ParsedTSIndex(CompositeTSIndex, ABC):
"""
@@ -290,9 +348,9 @@ class ParsedTSIndex(CompositeTSIndex, ABC):
"""
def __init__(
- self, composite_ts_idx: StructField, src_str_col: str, parsed_col: str
+ self, ts_idx: StructField, src_str_col: str, parsed_col: str
) -> None:
- super().__init__(composite_ts_idx, primary_ts_col=parsed_col)
+ super().__init__(ts_idx, parsed_col)
src_str_field = self.struct[src_str_col]
if not isinstance(src_str_field.dataType, StringType):
raise TypeError(
@@ -301,8 +359,8 @@ def __init__(
self.__src_str_col = src_str_col
@property
- def indexAttributes(self) -> dict[str, Any]:
- attrs = super().indexAttributes
+ def _indexAttributes(self) -> dict[str, Any]:
+ attrs = super()._indexAttributes
attrs["src_str_col"] = self.src_str_col
return attrs
@@ -310,6 +368,17 @@ def indexAttributes(self) -> dict[str, Any]:
def src_str_col(self):
return self.component(self.__src_str_col)
+ def validate(self, df_schema: StructType) -> None:
+ super().validate(df_schema)
+ # make sure the parsed field exists
+ composite_idx_type: StructType = cast(StructType, df_schema[self.colname].dataType)
+ assert self.__src_str_col in composite_idx_type.fieldNames(), \
+     f"The src_str_col column {self.src_str_col} does not exist in the composite field {composite_idx_type}"
+ # make sure it's StringType
+ src_str_field_type = composite_idx_type[self.__src_str_col].dataType
+ assert isinstance(src_str_field_type, StringType), \
+     f"The src_str_col column {self.src_str_col} should be of StringType, but found {src_str_field_type} instead"
+
class ParsedTimestampIndex(ParsedTSIndex):
"""
@@ -317,15 +386,15 @@ class ParsedTimestampIndex(ParsedTSIndex):
"""
def __init__(
- self, composite_ts_idx: StructField, src_str_col: str, parsed_col: str
+ self, ts_idx: StructField, src_str_col: str, parsed_col: str
) -> None:
- super().__init__(composite_ts_idx, src_str_col, parsed_col)
+ super().__init__(ts_idx, src_str_col, parsed_col)
if not isinstance(self.primary_ts_idx.dataType, TimestampType):
raise TypeError(
- f"ParsedTimestampIndex must be of TimestampType, but given ts_col {self.primary_ts_idx.name} has type {self.primary_ts_idx.dataType}"
+ f"ParsedTimestampIndex must be of TimestampType, but given ts_col {self.primary_ts_idx.colname} has type {self.primary_ts_idx.dataType}"
)
- def rangeOrderByExpr(self, reverse: bool = False) -> Column:
+ def rangeExpr(self, reverse: bool = False) -> Column:
# cast timestamp to double (fractional seconds since epoch)
expr = Fn.col(self.primary_ts_col).cast("double")
return self._reverseOrNot(expr, reverse)
@@ -337,15 +406,15 @@ class ParsedDateIndex(ParsedTSIndex):
"""
def __init__(
- self, composite_ts_idx: StructField, src_str_col: str, parsed_col: str
+ self, ts_idx: StructField, src_str_col: str, parsed_col: str
) -> None:
- super().__init__(composite_ts_idx, src_str_col, parsed_col)
+ super().__init__(ts_idx, src_str_col, parsed_col)
if not isinstance(self.primary_ts_idx.dataType, DateType):
raise TypeError(
- f"ParsedDateIndex must be of DateType, but given ts_col {self.primary_ts_idx.name} has type {self.primary_ts_idx.dataType}"
+ f"ParsedDateIndex must be of DateType, but given ts_col {self.primary_ts_idx.colname} has type {self.primary_ts_idx.dataType}"
)
- def rangeOrderByExpr(self, reverse: bool = False) -> Column:
+ def rangeExpr(self, reverse: bool = False) -> Column:
# convert date to number of days since the epoch
expr = Fn.datediff(
Fn.col(self.primary_ts_col), Fn.lit("1970-01-01").cast("date")
@@ -364,11 +433,19 @@ class TSSchema:
"""
def __init__(self, ts_idx: TSIndex, series_ids: Collection[str] = None) -> None:
- self.ts_idx = ts_idx
+ self.__ts_idx = ts_idx
if series_ids:
- self.series_ids = list(series_ids)
+ self.__series_ids = list(series_ids)
else:
- self.series_ids = []
+ self.__series_ids = []
+
+ @property
+ def ts_idx(self):
+ return self.__ts_idx
+
+ @property
+ def series_ids(self) -> List[str]:
+ return self.__series_ids
def __eq__(self, o: object) -> bool:
# must be of TSSchema type
@@ -406,10 +483,15 @@ def structural_columns(self) -> list[str]:
:return: a set of column names corresponding the structural columns of a :class:`TSDF`
"""
- return list({self.ts_idx.name}.union(self.series_ids))
+ return list({self.ts_idx.colname}.union(self.series_ids))
def validate(self, df_schema: StructType) -> None:
- pass
+ # ensure that the TSIndex is valid
+ self.ts_idx.validate(df_schema)
+ # check series IDs
+ for sid in self.series_ids:
+ assert sid in df_schema.fieldNames(), \
+     f"Series ID {sid} does not exist in the given DataFrame"
def find_observational_columns(self, df_schema: StructType) -> list[str]:
return list(set(df_schema.fieldNames()) - set(self.structural_columns))
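
A minimal end-to-end sketch of the new classes working together, assuming a struct-valued index holding a parsed timestamp plus a sub-sequence number (all names are illustrative):

    from pyspark.sql.types import (
        StructType, StructField, TimestampType, IntegerType, StringType, DoubleType
    )

    # a DataFrame schema whose ts index is a struct of (event_ts, seq_no)
    df_schema = StructType([
        StructField("ts_idx", StructType([
            StructField("event_ts", TimestampType()),
            StructField("seq_no", IntegerType()),
        ])),
        StructField("symbol", StringType()),
        StructField("price", DoubleType()),
    ])

    ts_idx = CompositeTSIndex(df_schema["ts_idx"], "event_ts", "seq_no")
    ts_idx.ts_component(1)  # "ts_idx.seq_no", the secondary (tie-breaker) field
    ts_idx.unit             # TimeUnits.SECONDS, taken from the primary component

    schema = TSSchema(ts_idx, series_ids=["symbol"])
    schema.validate(df_schema)  # raises AssertionError if the index or a series id is missing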
diff --git a/python/tests/base.py b/python/tests/base.py
index e8bc52f1..cadc8f8d 100644
--- a/python/tests/base.py
+++ b/python/tests/base.py
@@ -178,8 +178,8 @@ def assertFieldsEqual(self, fieldA, fieldB):
Test that two fields are equivalent
"""
self.assertEqual(
- fieldA.name.lower(),
- fieldB.name.lower(),
+ fieldA.colname.lower(),
+ fieldB.colname.lower(),
msg=f"Field {fieldA} has different name from {fieldB}",
)
self.assertEqual(
@@ -195,9 +195,9 @@ def assertSchemaContainsField(self, schema, field):
"""
# the schema must contain a field with the right name
lc_fieldNames = [fc.lower() for fc in schema.fieldNames()]
- self.assertTrue(field.name.lower() in lc_fieldNames)
+ self.assertTrue(field.colname.lower() in lc_fieldNames)
# the attributes of the fields must be equal
- self.assertFieldsEqual(field, schema[field.name])
+ self.assertFieldsEqual(field, schema[field.colname])
@staticmethod
def assertDataFrameEquality(
diff --git a/python/tests/interpol_tests.py b/python/tests/interpol_tests.py
index 149603ad..d38ea1ed 100644
--- a/python/tests/interpol_tests.py
+++ b/python/tests/interpol_tests.py
@@ -395,11 +395,8 @@ def test_interpolation_using_custom_params(self):
simple_input_tsdf: TSDF = self.get_data_as_tsdf("simple_input_data")
expected_df: DataFrame = self.get_data_as_sdf("expected")
- input_tsdf = TSDF(
- simple_input_tsdf.df.withColumnRenamed("event_ts", "other_ts_col"),
- ts_col="other_ts_col",
- series_ids=["partition_a", "partition_b"],
- )
+ input_tsdf = TSDF(simple_input_tsdf.df.withColumnRenamed("event_ts", "other_ts_col"), ts_col="other_ts_col",
+ series_ids=["partition_a", "partition_b"])
actual_df: DataFrame = input_tsdf.interpolate(
ts_col="other_ts_col",
diff --git a/python/tests/tsdf_tests.py b/python/tests/tsdf_tests.py
index 56098f7b..27eab7e0 100644
--- a/python/tests/tsdf_tests.py
+++ b/python/tests/tsdf_tests.py
@@ -70,9 +70,7 @@ def __tsdf_with_double_tscol(tsdf: TSDF) -> TSDF:
with_double_tscol_df = tsdf.df.withColumn(
tsdf.ts_col, Fn.col(tsdf.ts_col).cast("double")
)
- return TSDF(
- with_double_tscol_df, ts_col=tsdf.ts_col, series_ids=tsdf.series_ids
- )
+ return TSDF(with_double_tscol_df, ts_col=tsdf.ts_col, series_ids=tsdf.series_ids)
# TODO - replace this with test code for TSDF.fromTimestampString with nano-second precision
# def test__add_double_ts(self):