Skip to content

Commit

Permalink
CNAM-154 Important fixes in MLPP featuring
Browse files Browse the repository at this point in the history
- Trackloss and diagnostic are no longer stop conditions for the featuring
- Added the option to remove patients who had a single trackloss
  • Loading branch information
danielpes committed Dec 6, 2016
1 parent f384aa6 commit 5ed7fb1
Show file tree
Hide file tree
Showing 8 changed files with 149 additions and 34 deletions.
2 changes: 2 additions & 0 deletions src/main/resources/config/filtering-default.conf
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ default = {
start_delay = 0
purchases_window = 0
only_first = false
filter_lost_patients = false
filter_diagnosed_patients = true
diagnosed_patients_threshold = 0
filter_delayed_entries = true
delayed_entry_threshold = 12
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ object MLPPConfig {
startDelay: Int,
purchasesWindow: Int,
onlyFirst: Boolean,
filterLostPatients: Boolean,
filterDiagnosedPatients: Boolean,
diagnosedPatientsThreshold: Int,
filterDelayedEntries: Boolean,
delayedEntryThreshold: Int
)
Expand All @@ -30,7 +32,9 @@ object MLPPConfig {
startDelay = conf.getInt("exposures.start_delay"),
purchasesWindow = conf.getInt("exposures.purchases_window"),
onlyFirst = conf.getBoolean("exposures.only_first"),
filterLostPatients = conf.getBoolean("exposures.filter_lost_patients"),
filterDiagnosedPatients = conf.getBoolean("exposures.filter_diagnosed_patients"),
diagnosedPatientsThreshold = conf.getInt("exposures.diagnosed_patients_threshold"),
filterDelayedEntries = conf.getBoolean("exposures.filter_delayed_entries"),
delayedEntryThreshold = conf.getInt("exposures.delayed_entry_threshold")
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ object MLPPExposuresTransformer extends ExposuresTransformer {
private def startDelay = MLPPConfig.exposureDefinition.startDelay
private def purchasesWindow = MLPPConfig.exposureDefinition.purchasesWindow
private def onlyFirstExposure = MLPPConfig.exposureDefinition.onlyFirst
private def filterLostPatients = MLPPConfig.exposureDefinition.filterLostPatients
private def filterDelayedEntries = MLPPConfig.exposureDefinition.filterDelayedEntries
private def delayedEntryThreshold = MLPPConfig.exposureDefinition.delayedEntryThreshold
private def filterDiagnosedPatients = MLPPConfig.exposureDefinition.filterDiagnosedPatients
private def diagnosedPatientsThreshold = MLPPConfig.exposureDefinition.diagnosedPatientsThreshold

val outputColumns = List(
col("patientID"),
Expand All @@ -31,17 +33,22 @@ object MLPPExposuresTransformer extends ExposuresTransformer {
implicit class ExposuresDataFrame(data: DataFrame) {

/**
* Drops patients whose got a target disease before periodStart
* Drops patients who got a target disease before StudyStart + delay in months (default = 0)
*/
def filterDiagnosedPatients(doFilter: Boolean): DataFrame = {

if (doFilter) {
val window = Window.partitionBy("patientID")

val dateThreshold: Column = add_months(
lit(StudyStart), diagnosedPatientsThreshold
).cast(TimestampType)

val filterColumn: Column = min(
when(
col("category") === "disease" &&
col("eventId") === "targetDisease" &&
(col("start") < StudyStart), lit(0)
(col("category") === "disease") &&
(col("eventId") === "targetDisease") &&
(col("start") < dateThreshold), lit(0)
).otherwise(lit(1))
).over(window).cast(BooleanType)

Expand All @@ -53,7 +60,7 @@ object MLPPExposuresTransformer extends ExposuresTransformer {
}

/**
* Drops patients whose first molecule event is after StudyStart + 1 year
* Drops patients whose first molecule event is after StudyStart + delay (default: 1 year)
*/
def filterDelayedEntries(doFilter: Boolean): DataFrame = {
if (doFilter) {
Expand All @@ -77,6 +84,26 @@ object MLPPExposuresTransformer extends ExposuresTransformer {
}
}

/**
 * Removes every row belonging to a patient who has at least one "trackloss" event
 * whose start is at or after StudyStart.
 *
 * @param doFilter when false, the data is returned untouched
 * @return the input rows restricted to the patients that were kept
 */
def filterLostPatients(doFilter: Boolean): DataFrame = {
  if (!doFilter) {
    data
  }
  else {
    val isLateTrackloss: Column =
      (col("category") === "trackloss") && (col("start") >= StudyStart)

    // Each row is flagged 0 when it is such a trackloss and 1 otherwise; the minimum over the
    // patient's rows is then 0 (false once cast) as soon as one flagged row exists.
    val keepPatient: Column = min(when(isLateTrackloss, lit(0)).otherwise(lit(1)))
      .over(Window.partitionBy("patientID"))
      .cast(BooleanType)

    data
      .withColumn("filter", keepPatient)
      .where(col("filter"))
      .drop("filter")
  }
}

def withExposureStart(minPurchases: Int = 1, intervalSize: Int = 6,
startDelay: Int = 0, firstOnly: Boolean = false): DataFrame = {

Expand Down Expand Up @@ -116,6 +143,7 @@ object MLPPExposuresTransformer extends ExposuresTransformer {
input.toDF
.filterDelayedEntries(filterDelayedEntries)
.filterDiagnosedPatients(filterDiagnosedPatients)
.filterLostPatients(filterLostPatients)
.where(col("category") === "molecule")
.withExposureStart(
minPurchases = minPurchases,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
package fr.polytechnique.cmap.cnam.filtering.mlpp

import org.apache.spark.sql.Dataset
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.functions._
import fr.polytechnique.cmap.cnam.Main
import fr.polytechnique.cmap.cnam.filtering.{FilteringConfig, FilteringMain, FlatEvent}
import fr.polytechnique.cmap.cnam.filtering._

object MLPPMain extends Main {

override def appName: String = "MLPPFeaturing"

def run(sqlContext: HiveContext, argsMap: Map[String, String] = Map()): Option[Dataset[MLPPFeature]] = {

import sqlContext.implicits._

// "get" returns an Option, so we can use foreach to silently skip any key that was not found.
argsMap.get("conf").foreach(sqlContext.setConf("conf", _))
argsMap.get("env").foreach(sqlContext.setConf("env", _))
Expand All @@ -20,7 +23,23 @@ object MLPPMain extends Main {
.filter(e => e.category == "molecule" || e.category == "disease").cache()

val diseaseEvents: Dataset[FlatEvent] = flatEvents.filter(_.category == "disease")
val exposures: Dataset[FlatEvent] = MLPPExposuresTransformer.transform(flatEvents)
val dcirFlat: DataFrame = sqlContext.read.parquet(FilteringConfig.inputPaths.dcir)

val patients: Dataset[Patient] = flatEvents.map(
e => Patient(e.patientID, e.gender, e.birthDate, e.deathDate)
).distinct
val tracklossEvents: Dataset[Event] = TrackLossTransformer.transform(
Sources(dcir=Some(dcirFlat))
)
val tracklossFlatEvents = tracklossEvents
.as("left")
.joinWith(patients.as("right"), col("left.patientID") === col("right.patientID"))
.map((FlatEvent.merge _).tupled)
.cache()

val allEvents = flatEvents.union(tracklossFlatEvents)

val exposures: Dataset[FlatEvent] = MLPPExposuresTransformer.transform(allEvents)

val mlppParams = MLPPWriter.Params(
bucketSize = MLPPConfig.bucketSize,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,17 +66,18 @@ class MLPPWriter(params: MLPPWriter.Params = MLPPWriter.Params()) {

val hadDisease: Column = (col("category") === "disease") &&
(col("eventId") === "targetDisease") &&
(col("startBucket") < minColumn(col("tracklossBucket"), col("deathBucket"), lit(bucketCount)))
(col("startBucket") < minColumn(col("deathBucket"), lit(bucketCount)))

val diseaseBucket: Column = min(when(hadDisease, col("startBucket"))).over(window)

data.withColumn("diseaseBucket", diseaseBucket)
}

// We are no longer using trackloss and disease information for calculating the end bucket.
def withEndBucket: DataFrame = {

val endBucket: Column = minColumn(
col("tracklossBucket"), col("diseaseBucket"), col("deathBucket"), lit(bucketCount)
col("deathBucket"), lit(bucketCount)
)
data.withColumn("endBucket", endBucket)
}
Expand Down Expand Up @@ -267,7 +268,6 @@ class MLPPWriter(params: MLPPWriter.Params = MLPPWriter.Params()) {
.withAge(AgeReferenceDate)
.withStartBucket
.withDeathBucket
.withTracklossBucket
.withDiseaseBucket
.withEndBucket
.where(col("category") === "exposure")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class MLPPExposuresTransformerSuite extends SharedContext {
("Patient_A", "molecule", "", makeTS(2008, 1, 10)),
("Patient_A", "disease", "targetDisease", makeTS(2005, 1, 1)),
("Patient_B", "molecule", "", makeTS(2009, 1, 1)),
("Patient_B", "disease", "targetDisease", makeTS(2009, 1, 1)),
("Patient_B", "disease", "targetDisease", makeTS(2006, 1, 1)),
("Patient_C", "molecule", "", makeTS(2006, 1, 1))
).toDF("patientID", "category", "eventId", "start")

Expand All @@ -81,6 +81,8 @@ class MLPPExposuresTransformerSuite extends SharedContext {

// Then
import RichDataFrames._
result.show
expected.show
assert(result === expected)
}

Expand All @@ -105,6 +107,34 @@ class MLPPExposuresTransformerSuite extends SharedContext {
assert(result === expected)
}

// Checks that filterLostPatients(true) removes all rows of a patient who has a trackloss row,
// while patients without any trackloss row are returned unchanged.
"filterLostPatients" should "remove patients when they have a trackloss events" in {
val sqlCtx = sqlContext
import sqlCtx.implicits._

// Given
// Patient_B is the only patient with a "trackloss" row.
// NOTE(review): assumes 2007-01-01 is not before StudyStart, otherwise the row would not
// trigger the filter -- confirm against the configured study start.
val input = Seq(
("Patient_A", "molecule", makeTS(2006, 1, 1)),
("Patient_A", "molecule", makeTS(2006, 2, 1)),
("Patient_B", "molecule", makeTS(2006, 5, 1)),
("Patient_B", "trackloss", makeTS(2007, 1, 1)),
("Patient_C", "molecule", makeTS(2006, 11, 1))
).toDF("patientID", "category", "start")

// Both Patient_B rows (the molecule and the trackloss) are expected to be gone.
val expected = Seq(
("Patient_A", "molecule", makeTS(2006, 1, 1)),
("Patient_A", "molecule", makeTS(2006, 2, 1)),
("Patient_C", "molecule", makeTS(2006, 11, 1))
).toDF("patientID", "category", "start")

// When
// The implicit class import is what makes filterLostPatients available on a DataFrame.
import MLPPExposuresTransformer.ExposuresDataFrame
val result = input.filterLostPatients(true)

// Then
import RichDataFrames._
assert(result === expected)
}

"withExposureStart" should "add a column with the start of the default MLPP exposure definition" in {
val sqlCtx = sqlContext
import sqlCtx.implicits._
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,26 @@ class MLPPMainSuite extends SharedContext {
lazy val featuresPath = FilteringConfig.outputPaths.mlppFeatures

val expectedFeatures: DataFrame = Seq(
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 0, 0, 0, 0, 1.0),
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 1, 0, 1, 0, 1.0),
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 1, 1, 1, 1, 1.0),
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 2, 1, 2, 1, 1.0),
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 2, 2, 2, 2, 1.0)
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 0, 0, 0, 0, 1.0),
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 1, 0, 1, 0, 1.0),
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 1, 1, 1, 1, 1.0),
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 2, 1, 2, 1, 1.0),
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 2, 2, 2, 2, 1.0),
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 3, 2, 3, 2, 1.0),
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 3, 3, 3, 3, 1.0),
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 4, 3, 4, 3, 1.0),
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 4, 4, 4, 4, 1.0),
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 5, 4, 5, 4, 1.0),
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 5, 5, 5, 5, 1.0),
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 6, 5, 6, 5, 1.0),
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 6, 6, 6, 6, 1.0),
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 7, 6, 7, 6, 1.0),
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 7, 7, 7, 7, 1.0),
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 8, 7, 8, 7, 1.0),
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 8, 8, 8, 8, 1.0),
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 9, 8, 9, 8, 1.0),
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 9, 9, 9, 9, 1.0),
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 10, 9, 10, 9, 1.0)
).toDF

// When
Expand All @@ -56,7 +71,14 @@ class MLPPMainSuite extends SharedContext {
val expectedFeatures: DataFrame = Seq(
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 0, 0, 0, 0, 1.0),
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 1, 1, 1, 1, 1.0),
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 2, 2, 2, 2, 1.0)
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 2, 2, 2, 2, 1.0),
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 3, 3, 3, 3, 1.0),
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 4, 4, 4, 4, 1.0),
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 5, 5, 5, 5, 1.0),
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 6, 6, 6, 6, 1.0),
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 7, 7, 7, 7, 1.0),
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 8, 8, 8, 8, 1.0),
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 9, 9, 9, 9, 1.0)
).toDF

// When
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,14 +221,14 @@ class MLPPWriterSuite extends SharedContext {
val expected = Seq(
("PA", Some(2)),
("PA", Some(2)),
("PB", Some(3)),
("PB", Some(3)),
("PC", Some(4)),
("PC", Some(4)),
("PB", Some(4)),
("PB", Some(4)),
("PC", Some(16)),
("PC", Some(16)),
("PD", Some(5)),
("PD", Some(5)),
("PE", Some(6)),
("PE", Some(6)),
("PE", Some(7)),
("PE", Some(7)),
("PF", Some(16))
).toDF("patientID", "endBucket")

Expand Down Expand Up @@ -605,16 +605,26 @@ class MLPPWriterSuite extends SharedContext {

val expectedFeatures = Seq(
// Patient A
MLPPFeature("PA", 0, "Mol1", 0, 0, 0, 0, 0, 1.0),
MLPPFeature("PA", 0, "Mol1", 0, 1, 1, 1, 1, 1.0),
MLPPFeature("PA", 0, "Mol1", 0, 2, 2, 2, 2, 1.0),
MLPPFeature("PA", 0, "Mol1", 0, 3, 3, 3, 3, 1.0),
MLPPFeature("PA", 0, "Mol1", 0, 2, 0, 2, 0, 1.0),
MLPPFeature("PA", 0, "Mol1", 0, 3, 1, 3, 1, 1.0),
MLPPFeature("PA", 0, "Mol1", 0, 3, 0, 3, 0, 1.0),
MLPPFeature("PA", 0, "Mol2", 1, 2, 0, 2, 4, 1.0),
MLPPFeature("PA", 0, "Mol2", 1, 3, 1, 3, 5, 1.0),
MLPPFeature("PA", 0, "Mol3", 2, 3, 0, 3, 8, 1.0)
MLPPFeature("PA", 0, "Mol1", 0, 0, 0, 0, 0, 1.0),
MLPPFeature("PA", 0, "Mol1", 0, 1, 1, 1, 1, 1.0),
MLPPFeature("PA", 0, "Mol1", 0, 2, 2, 2, 2, 1.0),
MLPPFeature("PA", 0, "Mol1", 0, 3, 3, 3, 3, 1.0),
MLPPFeature("PA", 0, "Mol1", 0, 2, 0, 2, 0, 1.0),
MLPPFeature("PA", 0, "Mol1", 0, 3, 1, 3, 1, 1.0),
MLPPFeature("PA", 0, "Mol1", 0, 4, 2, 4, 2, 1.0),
MLPPFeature("PA", 0, "Mol1", 0, 5, 3, 5, 3, 1.0),
MLPPFeature("PA", 0, "Mol1", 0, 3, 0, 3, 0, 1.0),
MLPPFeature("PA", 0, "Mol1", 0, 4, 1, 4, 1, 1.0),
MLPPFeature("PA", 0, "Mol1", 0, 5, 2, 5, 2, 1.0),
MLPPFeature("PA", 0, "Mol1", 0, 6, 3, 6, 3, 1.0),
MLPPFeature("PA", 0, "Mol2", 1, 2, 0, 2, 4, 1.0),
MLPPFeature("PA", 0, "Mol2", 1, 3, 1, 3, 5, 1.0),
MLPPFeature("PA", 0, "Mol2", 1, 4, 2, 4, 6, 1.0),
MLPPFeature("PA", 0, "Mol2", 1, 5, 3, 5, 7, 1.0),
MLPPFeature("PA", 0, "Mol3", 2, 3, 0, 3, 8, 1.0),
MLPPFeature("PA", 0, "Mol3", 2, 4, 1, 4, 9, 1.0),
MLPPFeature("PA", 0, "Mol3", 2, 5, 2, 5, 10, 1.0),
MLPPFeature("PA", 0, "Mol3", 2, 6, 3, 6, 11, 1.0)
).toDF

val expectedZMatrix = Seq(
Expand Down

0 comments on commit 5ed7fb1

Please sign in to comment.