Skip to content

Commit

Permalink
Ensure queries returned via REST API are redacted
Browse files Browse the repository at this point in the history
@JsonConstructor for TrimmedBasicQueryInfo was introduced to facilitate
the deserialization of server responses in tests.
  • Loading branch information
piotrrzysko committed Jan 20, 2025
1 parent e2f080c commit 2490789
Show file tree
Hide file tree
Showing 5 changed files with 441 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.server.ui;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.errorprone.annotations.Immutable;
import io.trino.execution.QueryState;
Expand Down Expand Up @@ -54,6 +55,45 @@ public class TrimmedBasicQueryInfo
private final Optional<QueryType> queryType;
private final RetryPolicy retryPolicy;

@JsonCreator
public TrimmedBasicQueryInfo(
@JsonProperty("queryId") QueryId queryId,
@JsonProperty("sessionUser") String sessionUser,
@JsonProperty("sessionPrincipal") Optional<String> sessionPrincipal,
@JsonProperty("sessionSource") Optional<String> sessionSource,
@JsonProperty("resourceGroupId") Optional<ResourceGroupId> resourceGroupId,
@JsonProperty("queryDataEncoding") Optional<String> queryDataEncoding,
@JsonProperty("state") QueryState state,
@JsonProperty("scheduled") boolean scheduled,
@JsonProperty("self") URI self,
@JsonProperty("queryTextPreview") String queryTextPreview,
@JsonProperty("updateType") Optional<String> updateType,
@JsonProperty("preparedQuery") Optional<String> preparedQuery,
@JsonProperty("queryStats") BasicQueryStats queryStats,
@JsonProperty("errorType") Optional<ErrorType> errorType,
@JsonProperty("errorCode") Optional<ErrorCode> errorCode,
@JsonProperty("queryType") Optional<QueryType> queryType,
@JsonProperty("retryPolicy") RetryPolicy retryPolicy)
{
this.queryId = requireNonNull(queryId, "queryId is null");
this.sessionUser = requireNonNull(sessionUser, "sessionUser is null");
this.sessionPrincipal = requireNonNull(sessionPrincipal, "sessionPrincipal is null");
this.sessionSource = requireNonNull(sessionSource, "sessionSource is null");
this.resourceGroupId = requireNonNull(resourceGroupId, "resourceGroupId is null");
this.queryDataEncoding = requireNonNull(queryDataEncoding, "queryDataEncoding is null");
this.state = requireNonNull(state, "state is null");
this.scheduled = scheduled;
this.self = requireNonNull(self, "self is null");
this.queryTextPreview = requireNonNull(queryTextPreview, "queryTextPreview is null");
this.updateType = requireNonNull(updateType, "updateType is null");
this.preparedQuery = requireNonNull(preparedQuery, "preparedQuery is null");
this.queryStats = requireNonNull(queryStats, "queryStats is null");
this.errorType = requireNonNull(errorType, "errorType is null");
this.errorCode = requireNonNull(errorCode, "errorCode is null");
this.queryType = requireNonNull(queryType, "queryType is null");
this.retryPolicy = requireNonNull(retryPolicy, "retryPolicy is null");
}

public TrimmedBasicQueryInfo(BasicQueryInfo queryInfo)
{
this.queryId = requireNonNull(queryInfo.getQueryId(), "queryId is null");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.server;

import com.google.common.collect.ImmutableSet;
import com.google.inject.Key;
import io.airlift.http.client.HttpClient;
import io.airlift.http.client.HttpUriBuilder;
Expand All @@ -29,6 +30,8 @@
import io.trino.client.QueryDataClientJacksonModule;
import io.trino.client.QueryResults;
import io.trino.client.ResultRowsDecoder;
import io.trino.connector.MockConnectorFactory;
import io.trino.connector.MockConnectorPlugin;
import io.trino.execution.QueryInfo;
import io.trino.plugin.tpch.TpchPlugin;
import io.trino.server.testing.TestingTrinoServer;
Expand Down Expand Up @@ -65,6 +68,7 @@
import static io.trino.testing.TestingAccessControlManager.TestingPrivilegeType.KILL_QUERY;
import static io.trino.testing.TestingAccessControlManager.TestingPrivilegeType.VIEW_QUERY;
import static io.trino.testing.TestingAccessControlManager.privilege;
import static io.trino.testing.TestingNames.randomNameSuffix;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
Expand Down Expand Up @@ -94,6 +98,9 @@ public void setup()
{
client = new JettyHttpClient();
server = TestingTrinoServer.create();
server.installPlugin(new MockConnectorPlugin(MockConnectorFactory.builder()
.withSecuritySensitivePropertyNames(ImmutableSet.of("password"))
.build()));
server.installPlugin(new TpchPlugin());
server.createCatalog("tpch", "tpch");
}
Expand Down Expand Up @@ -226,6 +233,47 @@ public void testGetQueryInfoExecutionFailure()
assertThat(info.getFailureInfo().getErrorCode()).isEqualTo(DIVISION_BY_ZERO.toErrorCode());
}

@Test
public void testGetQueryInfosWithRedactedSecrets()
{
String catalog = "catalog_" + randomNameSuffix();
runToCompletion("""
CREATE CATALOG %s USING mock
WITH (
"user" = 'bob',
"password" = '1234'
)""".formatted(catalog));

List<BasicQueryInfo> infos = getQueryInfos("/v1/query");
assertThat(infos.size()).isEqualTo(1);
assertThat(infos.getFirst().getQuery()).isEqualTo("""
CREATE CATALOG %s USING mock
WITH (
"user" = 'bob',
"password" = '***'
)""".formatted(catalog));
}

@Test
public void testGetQueryInfoWithRedactedSecrets()
{
String catalog = "catalog_" + randomNameSuffix();
String queryId = runToCompletion("""
CREATE CATALOG %s USING mock
WITH (
"user" = 'bob',
"password" = '1234'
)""".formatted(catalog));

QueryInfo queryInfo = getQueryInfo(queryId);
assertThat(queryInfo.getQuery()).isEqualTo("""
CREATE CATALOG %s USING mock
WITH (
"user" = 'bob',
"password" = '***'
)""".formatted(catalog));
}

@Test
public void testCancel()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.server;

import com.google.common.collect.ImmutableSet;
import com.google.common.io.Closer;
import io.airlift.http.client.HttpClient;
import io.airlift.http.client.Request;
Expand All @@ -23,6 +24,8 @@
import io.airlift.json.ObjectMapperProvider;
import io.airlift.units.Duration;
import io.trino.client.QueryResults;
import io.trino.connector.MockConnectorFactory;
import io.trino.connector.MockConnectorPlugin;
import io.trino.plugin.tpch.TpchPlugin;
import io.trino.server.protocol.spooling.QueryDataJacksonModule;
import io.trino.server.testing.TestingTrinoServer;
Expand All @@ -47,6 +50,7 @@
import static io.airlift.json.JsonCodec.listJsonCodec;
import static io.trino.client.ProtocolHeaders.TRINO_HEADERS;
import static io.trino.execution.QueryState.FAILED;
import static io.trino.execution.QueryState.FINISHING;
import static io.trino.execution.QueryState.RUNNING;
import static io.trino.server.TestQueryResource.BASIC_QUERY_INFO_CODEC;
import static io.trino.testing.TestingAccessControlManager.TestingPrivilegeType.VIEW_QUERY;
Expand All @@ -71,11 +75,15 @@ public class TestQueryStateInfoResource
private TestingTrinoServer server;
private HttpClient client;
private QueryResults queryResults;
private QueryResults createCatalogResults;

@BeforeAll
public void setUp()
{
server = TestingTrinoServer.create();
server.installPlugin(new MockConnectorPlugin(MockConnectorFactory.builder()
.withSecuritySensitivePropertyNames(ImmutableSet.of("password"))
.build()));
server.installPlugin(new TpchPlugin());
server.createCatalog("tpch", "tpch");
client = new JettyHttpClient();
Expand All @@ -96,6 +104,19 @@ public void setUp()
QueryResults queryResults2 = client.execute(request2, createJsonResponseHandler(QUERY_RESULTS_JSON_CODEC));
client.execute(prepareGet().setUri(queryResults2.getNextUri()).build(), createJsonResponseHandler(QUERY_RESULTS_JSON_CODEC));

Request createCatalogRequest = preparePost()
.setUri(uriBuilderFrom(server.getBaseUrl()).replacePath("/v1/statement").build())
.setBodyGenerator(createStaticBodyGenerator("""
CREATE CATALOG test_catalog USING mock
WITH (
"user" = 'bob',
"password" = '1234'
)""", UTF_8))
.setHeader(TRINO_HEADERS.requestUser(), "catalogCreator")
.build();
createCatalogResults = client.execute(createCatalogRequest, createJsonResponseHandler(jsonCodec(QueryResults.class)));
client.execute(prepareGet().setUri(createCatalogResults.getNextUri()).build(), createJsonResponseHandler(QUERY_RESULTS_JSON_CODEC));

// queries are started in the background, so they may not all be immediately visible
long start = System.nanoTime();
while (Duration.nanosSince(start).compareTo(new Duration(5, MINUTES)) < 0) {
Expand All @@ -105,8 +126,8 @@ public void setUp()
.setHeader(TRINO_HEADERS.requestUser(), "unknown")
.build(),
createJsonResponseHandler(BASIC_QUERY_INFO_CODEC));
if (queryInfos.size() == 2) {
if (queryInfos.stream().allMatch(info -> info.getState() == RUNNING)) {
if (queryInfos.size() == 3) {
if (queryInfos.stream().allMatch(info -> info.getState() == RUNNING || info.getState() == FINISHING)) {
break;
}

Expand Down Expand Up @@ -143,7 +164,12 @@ public void testGetAllQueryStateInfos()
.build(),
createJsonResponseHandler(listJsonCodec(QueryStateInfo.class)));

assertThat(infos).hasSize(2);
assertThat(infos.size()).isEqualTo(3);
QueryStateInfo createCatalogInfo = infos.stream()
.filter(info -> info.getQueryId().getId().equals(createCatalogResults.getId()))
.findFirst()
.orElse(null);
assertCreateCatalogQueryIsRedacted(createCatalogInfo);
}

@Test
Expand Down Expand Up @@ -185,6 +211,19 @@ public void testGetQueryStateInfo()
assertThat(info).isNotNull();
}

@Test
public void testGetQueryStateInfoWithRedactedSecrets()
{
QueryStateInfo info = client.execute(
prepareGet()
.setUri(server.resolve("/v1/queryState/" + createCatalogResults.getId()))
.setHeader(TRINO_HEADERS.requestUser(), "unknown")
.build(),
createJsonResponseHandler(jsonCodec(QueryStateInfo.class)));

assertCreateCatalogQueryIsRedacted(info);
}

@Test
public void testGetAllQueryStateInfosDenied()
{
Expand All @@ -194,7 +233,7 @@ public void testGetAllQueryStateInfosDenied()
.setHeader(TRINO_HEADERS.requestUser(), "any-other-user")
.build(),
createJsonResponseHandler(listJsonCodec(QueryStateInfo.class)));
assertThat(infos).hasSize(2);
assertThat(infos).hasSize(3);

testGetAllQueryStateInfosDenied("user1", 1);
testGetAllQueryStateInfosDenied("any-other-user", 0);
Expand Down Expand Up @@ -249,4 +288,15 @@ public void testGetQueryStateInfoNo()
.isInstanceOf(UnexpectedResponseException.class)
.hasMessageMatching("Expected response code .*, but was 404");
}

private static void assertCreateCatalogQueryIsRedacted(QueryStateInfo info)
{
assertThat(info).isNotNull();
assertThat(info.getQuery()).isEqualTo("""
CREATE CATALOG test_catalog USING mock
WITH (
"user" = 'bob',
"password" = '***'
)""");
}
}
Loading

0 comments on commit 2490789

Please sign in to comment.