Skip to content

Commit

Permalink
New Survival Data Endpoint for Dynamic Kaplan-Meier Curve (#10774)
Browse files Browse the repository at this point in the history
Add a new survival data api endpoint that supports customizable Dynamic Kaplan-Meier plot.
---------

Co-authored-by: Karthik <[email protected]>
Co-authored-by: Qi-Xuan Lu <[email protected]>
  • Loading branch information
3 people authored Jun 12, 2024
1 parent 387509c commit 52cbbcf
Show file tree
Hide file tree
Showing 24 changed files with 1,350 additions and 169 deletions.
2 changes: 1 addition & 1 deletion src/main/java/org/cbioportal/model/ClinicalData.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public String getAttrId() {
}

public Boolean isPatientAttribute() {
if(clinicalAttribute == null) {
if (clinicalAttribute == null) {
return null;
}
return this.clinicalAttribute.getPatientAttribute();
Expand Down
17 changes: 16 additions & 1 deletion src/main/java/org/cbioportal/model/ClinicalEvent.java
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package org.cbioportal.model;

import java.util.List;
import jakarta.validation.constraints.NotNull;

import java.util.List;
import java.util.Objects;

public class ClinicalEvent extends UniqueKeyBase {

private Integer clinicalEventId;
Expand Down Expand Up @@ -71,4 +73,17 @@ public List<ClinicalEventData> getAttributes() {
public void setAttributes(List<ClinicalEventData> attributes) {
this.attributes = attributes;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ClinicalEvent that = (ClinicalEvent) o;
return Objects.equals(clinicalEventId, that.clinicalEventId) && Objects.equals(studyId, that.studyId) && Objects.equals(patientId, that.patientId) && Objects.equals(eventType, that.eventType) && Objects.equals(startDate, that.startDate) && Objects.equals(stopDate, that.stopDate) && Objects.equals(attributes, that.attributes);
}

@Override
public int hashCode() {
return Objects.hash(clinicalEventId, studyId, patientId, eventType, startDate, stopDate, attributes);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,11 @@ List<ClinicalEvent> getAllClinicalEventsInStudy(String studyId, String projectio

@Cacheable(cacheResolver = "generalRepositoryCacheResolver", condition = "@cacheEnabledConfig.getEnabled()")
List<ClinicalEvent> getPatientsDistinctClinicalEventInStudies(List<String> studyIds, List<String> patientIds);


@Cacheable(cacheResolver = "generalRepositoryCacheResolver", condition = "@cacheEnabledConfig.getEnabled()")
List<ClinicalEvent> getTimelineEvents(List<String> studyIds, List<String> patientIds, List<ClinicalEvent> clinicalEvents);

@Cacheable(cacheResolver = "generalRepositoryCacheResolver", condition = "@cacheEnabledConfig.getEnabled()")
List<ClinicalEvent> getClinicalEventsMeta(List<String> studyIds, List<String> patientIds, List<ClinicalEvent> clinicalEvents);
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,8 @@ List<ClinicalEvent> getStudyClinicalEvent(String studyId, String projection, Int
List<ClinicalEvent> getSamplesOfPatientsPerEventType(List<String> studyIds, List<String> sampleIds);

List<ClinicalEvent> getPatientsDistinctClinicalEventInStudies(List<String> studyIds, List<String> patientIds);

List<ClinicalEvent> getTimelineEvents(List<String> studyIds, List<String> patientIds, List<ClinicalEvent> clinicalEvents);

List<ClinicalEvent> getClinicalEventsMeta(List<String> studyIds, List<String> patientIds, List<ClinicalEvent> clinicalEvents);
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,16 @@ public Map<String, Set<String>> getSamplesOfPatientsPerEventTypeInStudy(List<Str
public List<ClinicalEvent> getPatientsDistinctClinicalEventInStudies(List<String> studyIds, List<String> patientIds) {
return clinicalEventMapper.getPatientsDistinctClinicalEventInStudies(studyIds, patientIds);
}

@Override
public List<ClinicalEvent> getTimelineEvents(List<String> studyIds,
List<String> patientIds,
List<ClinicalEvent> clinicalEvents) {
return clinicalEventMapper.getTimelineEvents(studyIds, patientIds, clinicalEvents);
}

@Override
public List<ClinicalEvent> getClinicalEventsMeta(List<String> studyIds, List<String> patientIds, List<ClinicalEvent> clinicalEvents) {
return clinicalEventMapper.getClinicalEventsMeta(studyIds, patientIds, clinicalEvents);
}
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package org.cbioportal.service;

import org.cbioportal.model.ClinicalData;
import org.cbioportal.model.ClinicalEvent;
import org.cbioportal.model.ClinicalEventTypeCount;
import org.cbioportal.model.meta.BaseMeta;
import org.cbioportal.service.exception.PatientNotFoundException;
import org.cbioportal.service.exception.StudyNotFoundException;
import org.cbioportal.web.parameter.SurvivalRequest;

import java.util.List;
import java.util.Map;
Expand All @@ -29,4 +31,11 @@ BaseMeta getMetaClinicalEvents(String studyId)
Map<String, Set<String>> getPatientsSamplesPerClinicalEventType(List<String> studyIds, List<String> sampleIds);

List<ClinicalEventTypeCount> getClinicalEventTypeCounts(List<String> studyIds, List<String> sampleIds);

List<ClinicalData> getSurvivalData(List<String> studyIds, List<String> patientIds,
String attributeIdPrefix,
SurvivalRequest survivalRequest);

List<ClinicalEvent> getClinicalEventsMeta(List<String> studyIds, List<String> patientIds,
List<ClinicalEvent> clinicalEvents);
}
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ public enum FrontendProperty {
skin_patient_view_copy_number_table_columns_show_on_init("skin.patient_view.copy_number_table.columns.show_on_init", null),
skin_patient_view_structural_variant_table_columns_show_on_init("skin.patient_view.structural_variant_table.columns.show_on_init", null),
skin_results_view_tables_default_sort_column("skin.results_view.tables.default_sort_column", null),
skin_survival_plot_clinical_event_types_show_on_init("skin.survival_plot.clinical_event_types.show_on_init", null),

skin_patient_view_tables_default_sort_column("skin.patient_view.tables.default_sort_column", null),
enable_treatment_groups("enable_treatment_groups", null),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,33 @@
package org.cbioportal.service.impl;

import org.cbioportal.model.*;
import org.apache.commons.collections4.CollectionUtils;
import org.cbioportal.model.ClinicalData;
import org.cbioportal.model.ClinicalEvent;
import org.cbioportal.model.ClinicalEventData;
import org.cbioportal.model.ClinicalEventTypeCount;
import org.cbioportal.model.Patient;
import org.cbioportal.model.meta.BaseMeta;
import org.cbioportal.persistence.ClinicalEventRepository;
import org.cbioportal.service.ClinicalEventService;
import org.cbioportal.service.PatientService;
import org.cbioportal.service.exception.PatientNotFoundException;
import org.cbioportal.service.exception.StudyNotFoundException;
import org.cbioportal.web.parameter.ClinicalEventRequestIdentifier;
import org.cbioportal.web.parameter.OccurrencePosition;
import org.cbioportal.web.parameter.SurvivalRequest;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import java.util.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.function.ToIntFunction;
import java.util.stream.Collectors;
import java.util.stream.Stream;

@Service
public class ClinicalEventServiceImpl implements ClinicalEventService {
Expand Down Expand Up @@ -108,4 +124,137 @@ public List<ClinicalEventTypeCount> getClinicalEventTypeCounts(List<String> stud
.map(e -> new ClinicalEventTypeCount(e.getKey(),e.getValue()))
.collect(Collectors.toList());
}

@Override
public List<ClinicalData> getSurvivalData(List<String> studyIds,
List<String> patientIds,
String attributeIdPrefix,
SurvivalRequest survivalRequest) {
List<ClinicalEvent> startClinicalEventsMeta = getToClinicalEvents(survivalRequest.getStartEventRequestIdentifier());
List<ClinicalEvent> patientStartEvents = clinicalEventRepository.getTimelineEvents(studyIds, patientIds, startClinicalEventsMeta);

// only fetch end timeline events for patients that have endClinicalEventsMeta and start timeline events
List<ClinicalEvent> patientEndEvents = filterClinicalEvents(patientStartEvents, survivalRequest.getEndEventRequestIdentifier());

ToIntFunction<ClinicalEvent> startPositionIdentifier = getPositionIdentifier(survivalRequest.getStartEventRequestIdentifier().getPosition());
ToIntFunction<ClinicalEvent> endPositionIdentifier = getPositionIdentifier(survivalRequest.getEndEventRequestIdentifier().getPosition());
Map<String, ClinicalEvent> patientEndEventsById = patientEndEvents.stream().collect(Collectors.toMap(ClinicalEventServiceImpl::getKey, Function.identity()));

// filter out cases where start event is less than end event
patientStartEvents = patientStartEvents.stream()
.filter(event ->
Optional.ofNullable(patientEndEventsById.get(getKey(event)))
.map(endPositionIdentifier::applyAsInt)
.map(endDate -> startPositionIdentifier.applyAsInt(event) < endDate)
.orElse(true)
).toList();

List<ClinicalEvent> patientCensoredEvents = filterClinicalEvents(patientStartEvents, survivalRequest.getCensoredEventRequestIdentifier());

return patientStartEvents.stream()
.flatMap(event -> {
ClinicalData clinicalDataMonths = buildClinicalSurvivalMonths(attributeIdPrefix, event, survivalRequest, patientEndEvents, patientCensoredEvents);
if (clinicalDataMonths == null) return Stream.empty();
ClinicalData clinicalDataStatus = buildClinicalSurvivalStatus(attributeIdPrefix, event, patientEndEvents);

return Stream.of(clinicalDataMonths, clinicalDataStatus);
}).toList();
}

@Override
public List<ClinicalEvent> getClinicalEventsMeta(List<String> studyIds, List<String> patientIds, List<ClinicalEvent> clinicalEvents) {
return clinicalEventRepository.getClinicalEventsMeta(studyIds, patientIds, clinicalEvents);
}

private static String getKey(ClinicalEvent clinicalEvent) {
return clinicalEvent.getStudyId() + clinicalEvent.getPatientId();
}

private static List<ClinicalEvent> getToClinicalEvents(ClinicalEventRequestIdentifier clinicalEventRequestIdentifier) {
return clinicalEventRequestIdentifier.getClinicalEventRequests().stream().map(x -> {
ClinicalEvent clinicalEvent = new ClinicalEvent();
clinicalEvent.setEventType(x.getEventType());
clinicalEvent.setAttributes(x.getAttributes());

return clinicalEvent;
}).toList();
}

private ToIntFunction<ClinicalEvent> getPositionIdentifier(OccurrencePosition position) {
return position.equals(OccurrencePosition.FIRST) ? ClinicalEvent::getStartDate : ClinicalEvent::getStopDate;
}

private List<ClinicalEvent> filterClinicalEvents(List<ClinicalEvent> patientEvents,
ClinicalEventRequestIdentifier clinicalEventRequestIdentifier) {
List<String> filteredStudyIds = new ArrayList<>();
List<String> filteredPatientIds = new ArrayList<>();
for (ClinicalEvent clinicalEvent : patientEvents) {
filteredStudyIds.add(clinicalEvent.getStudyId());
filteredPatientIds.add(clinicalEvent.getPatientId());
}

List<ClinicalEvent> clinicalEventsMeta = new ArrayList<>();
if (clinicalEventRequestIdentifier != null) {
clinicalEventsMeta = getToClinicalEvents(clinicalEventRequestIdentifier);
}

// only fetch end timeline events for patients that have endClinicalEventsMeta and start timeline events
List<ClinicalEvent> queriedPatientEvents = new ArrayList<>();
if (CollectionUtils.isNotEmpty(clinicalEventsMeta) && CollectionUtils.isNotEmpty(filteredStudyIds)) {
queriedPatientEvents = clinicalEventRepository.getTimelineEvents(filteredStudyIds, filteredPatientIds, clinicalEventsMeta);
}
return queriedPatientEvents;
}

private ClinicalData buildClinicalSurvivalMonths(String attributeIdPrefix, ClinicalEvent event, SurvivalRequest survivalRequest, List<ClinicalEvent> patientEndEvents, List<ClinicalEvent> patientCensoredEvents) {
final String SURVIVAL_MONTH_ATTRIBUTE = attributeIdPrefix + "_MONTHS";
ClinicalData clinicalDataMonths = new ClinicalData();
clinicalDataMonths.setStudyId(event.getStudyId());
clinicalDataMonths.setPatientId(event.getPatientId());
clinicalDataMonths.setAttrId(SURVIVAL_MONTH_ATTRIBUTE);

Map<String, ClinicalEvent> patientEndEventsById = patientEndEvents.stream().collect(Collectors.toMap(ClinicalEventServiceImpl::getKey, Function.identity()));
Map<String, ClinicalEvent> patientCensoredEventsById = patientCensoredEvents.stream().collect(Collectors.toMap(ClinicalEventServiceImpl::getKey, Function.identity()));

ToIntFunction<ClinicalEvent> startPositionIdentifier = getPositionIdentifier(survivalRequest.getStartEventRequestIdentifier().getPosition());
ToIntFunction<ClinicalEvent> endPositionIdentifier = survivalRequest.getEndEventRequestIdentifier() == null ? ClinicalEvent::getStopDate : getPositionIdentifier(survivalRequest.getEndEventRequestIdentifier().getPosition());
ToIntFunction<ClinicalEvent> censoredPositionIdentifier = survivalRequest.getCensoredEventRequestIdentifier() == null ? ClinicalEvent::getStopDate : getPositionIdentifier(survivalRequest.getCensoredEventRequestIdentifier().getPosition());

int startDate = startPositionIdentifier.applyAsInt(event);
int endDate;
if (patientEndEventsById.containsKey(getKey(event))) {
endDate = endPositionIdentifier.applyAsInt(patientEndEventsById.get(getKey(event)));
} else {
// ignore cases where patient does not have censored timeline events or
// stop date of start event is less than start date of censored events
if (!patientCensoredEventsById.containsKey(getKey(event)) ||
startDate >= censoredPositionIdentifier.applyAsInt(patientCensoredEventsById.get(getKey(event)))
) {
return null;
}

endDate = censoredPositionIdentifier.applyAsInt(patientCensoredEventsById.get(getKey(event)));
}
final String SURVIVAL_MONTH = String.valueOf((endDate - startDate) / 30.4);
clinicalDataMonths.setAttrValue(SURVIVAL_MONTH);

return clinicalDataMonths;
}

private ClinicalData buildClinicalSurvivalStatus(String attributeIdPrefix, ClinicalEvent event, List<ClinicalEvent> patientEndEvents) {
Map<String, ClinicalEvent> patientEndEventsById = patientEndEvents.stream().collect(Collectors.toMap(ClinicalEventServiceImpl::getKey, Function.identity()));

ClinicalData clinicalDataStatus = new ClinicalData();
clinicalDataStatus.setStudyId(event.getStudyId());
clinicalDataStatus.setPatientId(event.getPatientId());
clinicalDataStatus.setAttrId(attributeIdPrefix + "_STATUS");

if (patientEndEventsById.containsKey(getKey(event))) {
clinicalDataStatus.setAttrValue("1:EVENT");
} else {
clinicalDataStatus.setAttrValue("0:CENSORED");
}

return clinicalDataStatus;
}
}
45 changes: 45 additions & 0 deletions src/main/java/org/cbioportal/web/ClinicalEventController.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.validation.Valid;
import jakarta.validation.constraints.Max;
import jakarta.validation.constraints.Min;
import org.cbioportal.model.ClinicalEvent;
Expand All @@ -15,9 +16,11 @@
import org.cbioportal.service.exception.StudyNotFoundException;
import org.cbioportal.web.config.InternalApiTags;
import org.cbioportal.web.config.annotation.InternalApi;
import org.cbioportal.web.parameter.ClinicalEventAttributeRequest;
import org.cbioportal.web.parameter.Direction;
import org.cbioportal.web.parameter.HeaderKeyConstants;
import org.cbioportal.web.parameter.PagingConstants;
import org.cbioportal.web.parameter.PatientIdentifier;
import org.cbioportal.web.parameter.Projection;
import org.cbioportal.web.parameter.sort.ClinicalEventSortBy;
import org.springframework.beans.factory.annotation.Autowired;
Expand All @@ -28,11 +31,15 @@
import org.springframework.security.access.prepost.PreAuthorize;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestAttribute;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

@InternalApi
Expand Down Expand Up @@ -120,4 +127,42 @@ public ResponseEntity<List<ClinicalEvent>> getAllClinicalEventsInStudy(
sortBy == null ? null : sortBy.getOriginalValue(), direction.name()), HttpStatus.OK);
}
}

@PreAuthorize("hasPermission(#involvedCancerStudies, 'Collection<CancerStudyId>', T(org.cbioportal.utils.security.AccessLevel).READ)")
@RequestMapping(value = "/clinical-events-meta/fetch",
method = RequestMethod.POST,
consumes = MediaType.APPLICATION_JSON_VALUE,
produces = MediaType.APPLICATION_JSON_VALUE)
@Operation(description = "Fetch clinical events meta")
@ApiResponse(responseCode = "200", description = "OK",
content = @Content(array = @ArraySchema(schema = @Schema(implementation = ClinicalEvent.class))))
public ResponseEntity<List<ClinicalEvent>> fetchClinicalEventsMeta(
@Parameter(required = true, description = "clinical events Request")
@Valid @RequestBody(required = false) ClinicalEventAttributeRequest clinicalEventAttributeRequest,
@Parameter(hidden = true) // prevent reference to this attribute in the swagger-ui interface
@RequestAttribute(required = false, value = "involvedCancerStudies") Collection<String> involvedCancerStudies,
@Parameter(hidden = true) // prevent reference to this attribute in the swagger-ui interface. This attribute is needed for the @PreAuthorize tag above.
@Valid @RequestAttribute(required = false, value = "interceptedClinicalEventAttributeRequest") ClinicalEventAttributeRequest interceptedClinicalEventAttributeRequest) {

List<String> studyIds = new ArrayList<>();
List<String> patientIds = new ArrayList<>();
for (PatientIdentifier patientIdentifier : interceptedClinicalEventAttributeRequest.getPatientIdentifiers()) {
studyIds.add(patientIdentifier.getStudyId());
patientIds.add(patientIdentifier.getPatientId());
}

List<ClinicalEvent> clinicalEventsRequest = interceptedClinicalEventAttributeRequest.getClinicalEventRequests()
.stream()
.map(x -> {
ClinicalEvent clinicalEvent = new ClinicalEvent();
clinicalEvent.setEventType(x.getEventType());
clinicalEvent.setAttributes(x.getAttributes());
return clinicalEvent;
})
.toList();

return new ResponseEntity<>(clinicalEventService.getClinicalEventsMeta(
studyIds, patientIds, clinicalEventsRequest), HttpStatus.OK);

}
}
Loading

0 comments on commit 52cbbcf

Please sign in to comment.