Skip to content

Commit

Permalink
feat: allow for custom http clients to be provided
Browse files Browse the repository at this point in the history
The ability to customize clients makes it easier to integrate this library into applications more smoothly. Applications may wish to set custom `User-Agent` headers, configure proxies through ways other than Java system properties, hook in metrics collection, and more.
  • Loading branch information
nscuro committed Dec 3, 2023
1 parent ca0b299 commit d9b8297
Show file tree
Hide file tree
Showing 9 changed files with 183 additions and 42 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* SPDX-License-Identifier: Apache-2.0
* Copyright (c) 2023 Jeremy Long. All Rights Reserved.
*/
package io.github.jeremylong.openvulnerability.client;

import org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient;
import org.apache.hc.client5.http.impl.async.HttpAsyncClients;
import org.apache.hc.client5.http.impl.routing.SystemDefaultRoutePlanner;

import java.net.ProxySelector;
import java.util.function.Supplier;

/**
* Supplier for {@link CloseableHttpAsyncClient}s.
* <p>
* May be used to provide customized HTTP clients to data source clients.
* <p>
* Closing of the supplied {@link CloseableHttpAsyncClient} instances is a responsibility of the caller.
*/
@FunctionalInterface
public interface HttpAsyncClientSupplier extends Supplier<CloseableHttpAsyncClient> {

static HttpAsyncClientSupplier getDefault() {
return () -> {
SystemDefaultRoutePlanner planner = new SystemDefaultRoutePlanner(ProxySelector.getDefault());
return HttpAsyncClients.custom()
.setRoutePlanner(planner)
.useSystemProperties()
.build();
};
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* SPDX-License-Identifier: Apache-2.0
* Copyright (c) 2023 Jeremy Long. All Rights Reserved.
*/
package io.github.jeremylong.openvulnerability.client;

import org.apache.hc.client5.http.impl.classic.CloseableHttpClient;
import org.apache.hc.client5.http.impl.classic.HttpClientBuilder;
import org.apache.hc.client5.http.impl.routing.SystemDefaultRoutePlanner;

import java.net.ProxySelector;
import java.util.function.Supplier;

/**
* Supplier for {@link CloseableHttpClient}s.
* <p>
* May be used to provide customized HTTP clients to data source clients.
* <p>
* Closing of the supplied {@link CloseableHttpClient} instances is a responsibility of the caller.
*/
@FunctionalInterface
public interface HttpClientSupplier extends Supplier<CloseableHttpClient> {

static HttpClientSupplier getDefault() {
return () -> {
SystemDefaultRoutePlanner planner = new SystemDefaultRoutePlanner(ProxySelector.getDefault());
return HttpClientBuilder.create()
.setRoutePlanner(planner)
.useSystemProperties()
.build();
};
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,11 @@
package io.github.jeremylong.openvulnerability.client.epss;

import io.github.jeremylong.openvulnerability.client.DataFeed;
import io.github.jeremylong.openvulnerability.client.HttpClientSupplier;
import org.apache.hc.client5.http.classic.methods.HttpGet;
import org.apache.hc.client5.http.impl.classic.CloseableHttpClient;
import org.apache.hc.client5.http.impl.classic.HttpClientBuilder;
import org.apache.hc.client5.http.impl.routing.SystemDefaultRoutePlanner;

import java.io.IOException;
import java.net.ProxySelector;
import java.util.List;

/**
Expand All @@ -35,23 +33,27 @@
public class EpssDataFeed implements DataFeed<List<EpssItem>> {
private final static String DEFAULT_LOCATION = "https://epss.cyentia.com/epss_scores-current.csv.gz";

private final HttpClientSupplier httpClientSupplier;
private final String downloadUrl;

public EpssDataFeed() {
this.downloadUrl = DEFAULT_LOCATION;
this(DEFAULT_LOCATION);
}

public EpssDataFeed(String downloadUrl) {
this(downloadUrl, null);
}

public EpssDataFeed(String downloadUrl, HttpClientSupplier httpClientSupplier) {
this.downloadUrl = downloadUrl;
this.httpClientSupplier = httpClientSupplier != null ? httpClientSupplier : HttpClientSupplier.getDefault();
}

@Override
public List<EpssItem> download() {
List<EpssItem> list = null;
HttpGet request = new HttpGet(downloadUrl);
SystemDefaultRoutePlanner planner = new SystemDefaultRoutePlanner(ProxySelector.getDefault());
try (CloseableHttpClient client = HttpClientBuilder.create().setRoutePlanner(planner).useSystemProperties()
.build()) {
try (CloseableHttpClient client = httpClientSupplier.get()) {
list = client.execute(request, new EpssResponseHandler());
} catch (IOException e) {
e.printStackTrace();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,12 @@
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import com.samskivert.mustache.Mustache;
import com.samskivert.mustache.Template;
import io.github.jeremylong.openvulnerability.client.HttpAsyncClientSupplier;
import io.github.jeremylong.openvulnerability.client.PagedDataSource;
import org.apache.hc.client5.http.async.methods.SimpleHttpRequest;
import org.apache.hc.client5.http.async.methods.SimpleHttpResponse;
import org.apache.hc.client5.http.async.methods.SimpleRequestBuilder;
import org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient;
import org.apache.hc.client5.http.impl.async.HttpAsyncClients;
import org.apache.hc.client5.http.impl.routing.SystemDefaultRoutePlanner;
import org.apache.hc.core5.http.ContentType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -37,7 +36,6 @@
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.ProxySelector;
import java.nio.charset.StandardCharsets;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
Expand Down Expand Up @@ -141,7 +139,7 @@ public class GitHubSecurityAdvisoryClient implements PagedDataSource<SecurityAdv
* @param githubToken the GitHub API Token.
*/
public GitHubSecurityAdvisoryClient(String githubToken) {
this(githubToken, GITHUB_GRAPHQL_ENDPOINT);
this(githubToken, null);
}

/**
Expand All @@ -151,14 +149,29 @@ public GitHubSecurityAdvisoryClient(String githubToken) {
* @param githubToken the GitHub API Token.
*/
public GitHubSecurityAdvisoryClient(String githubToken, String endpoint) {
this(githubToken, endpoint, null);
}

/**
* Constructs a new client.
*
* @param githubToken the GitHub API Token.
* @param endpoint the GraphQL endpoint of GitHub or GHE.
* @param httpClientSupplier supplier for custom HTTP clients; if {@code null} a default client will be used
*/
public GitHubSecurityAdvisoryClient(String githubToken, String endpoint,
HttpAsyncClientSupplier httpClientSupplier) {
this.githubToken = githubToken;
this.endpoint = endpoint;
this.endpoint = endpoint != null ? endpoint : GITHUB_GRAPHQL_ENDPOINT;
advistoriesTemplate = loadMustacheTemplate(ADVISORIES_TEMPLATE);
vulnerabilitiesTemplate = loadMustacheTemplate(VULNERABILITIES_TEMPLATE);
cwesTemplate = loadMustacheTemplate(CWES_TEMPLATE);

SystemDefaultRoutePlanner planner = new SystemDefaultRoutePlanner(ProxySelector.getDefault());
httpClient = HttpAsyncClients.custom().setRoutePlanner(planner).useSystemProperties().build();
if (httpClientSupplier == null) {
httpClient = HttpAsyncClientSupplier.getDefault().get();
} else {
httpClient = httpClientSupplier.get();
}
httpClient.start();
objectMapper = new ObjectMapper();
objectMapper.registerModule(new JavaTimeModule());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package io.github.jeremylong.openvulnerability.client.ghsa;

import io.github.jeremylong.openvulnerability.client.HttpAsyncClientSupplier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -61,6 +62,7 @@ public final class GitHubSecurityAdvisoryClientBuilder {
* The publishedSince filter.
*/
private ZonedDateTime publishedSince;
private HttpAsyncClientSupplier httpClientSupplier;

/**
* Private constructor for a builder.
Expand Down Expand Up @@ -132,18 +134,24 @@ public GitHubSecurityAdvisoryClientBuilder withPublishedSinceFilter(ZonedDateTim
return this;
}

/**
* Provide a supplier for custom HTTP clients.
*
* @param httpClientSupplier supplier for custom HTTP clients; if {@code null} a default client will be used
* @return the builder
*/
public GitHubSecurityAdvisoryClientBuilder withHttpClientSupplier(HttpAsyncClientSupplier httpClientSupplier) {
this.httpClientSupplier = httpClientSupplier;
return this;
}

/**
* Build the GitHub SecurityAdvisory GraphQL API client.
*
* @return the GitHub SecurityAdvisory GraphQL API client
*/
public GitHubSecurityAdvisoryClient build() {
GitHubSecurityAdvisoryClient client;
if (endpoint == null) {
client = new GitHubSecurityAdvisoryClient(apiKey);
} else {
client = new GitHubSecurityAdvisoryClient(apiKey, endpoint);
}
GitHubSecurityAdvisoryClient client = new GitHubSecurityAdvisoryClient(apiKey, endpoint, httpClientSupplier);
if (classifications != null) {
client.setClassifications(classifications);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,12 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import io.github.jeremylong.openvulnerability.client.DataFeed;
import io.github.jeremylong.openvulnerability.client.HttpClientSupplier;
import org.apache.hc.client5.http.classic.methods.HttpGet;
import org.apache.hc.client5.http.impl.classic.BasicHttpClientResponseHandler;
import org.apache.hc.client5.http.impl.classic.CloseableHttpClient;
import org.apache.hc.client5.http.impl.classic.HttpClientBuilder;
import org.apache.hc.client5.http.impl.routing.SystemDefaultRoutePlanner;

import java.io.IOException;
import java.net.ProxySelector;

/**
* Data Feed for the CISA Known Exploited Vulnerabilities Catalog.
Expand All @@ -41,25 +39,29 @@ public class KevDataFeed implements DataFeed<KevCatalog> {
* Jackson object mapper.
*/
private final ObjectMapper objectMapper;
private final HttpClientSupplier httpClientSupplier;
private final String downloadUrl;

public KevDataFeed() {
this(DEFAULT_LOCATION);
}

public KevDataFeed(String downloadUrl) {
this(downloadUrl, null);
}

public KevDataFeed(String downloadUrl, HttpClientSupplier httpClientSupplier) {
this.downloadUrl = downloadUrl;
this.httpClientSupplier = httpClientSupplier != null ? httpClientSupplier : HttpClientSupplier.getDefault();
objectMapper = new ObjectMapper();
objectMapper.registerModule(new JavaTimeModule());
}

@Override
public KevCatalog download() {
HttpGet request = new HttpGet(downloadUrl);
SystemDefaultRoutePlanner planner = new SystemDefaultRoutePlanner(ProxySelector.getDefault());
String json;
try (CloseableHttpClient client = HttpClientBuilder.create().setRoutePlanner(planner).useSystemProperties()
.build()) {
try (CloseableHttpClient client = httpClientSupplier.get()) {
json = client.execute(request, new BasicHttpClientResponseHandler());
} catch (IOException e) {
throw new KevException("Unable to download the Known Exploitable Vulnerability Catalog", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import io.github.jeremylong.openvulnerability.client.HttpAsyncClientSupplier;
import io.github.jeremylong.openvulnerability.client.PagedDataSource;
import org.apache.hc.client5.http.async.methods.SimpleHttpRequest;
import org.apache.hc.client5.http.async.methods.SimpleHttpResponse;
Expand Down Expand Up @@ -133,7 +134,7 @@ public class NvdCveClient implements PagedDataSource<DefCveItem> {
* @param maxPageCount the maximum number of pages to retrieve from the NVD API.
*/
NvdCveClient(String apiKey, String endpoint, int threadCount, int maxPageCount) {
this(apiKey, endpoint, apiKey == null ? 6500 : 600, threadCount, maxPageCount, 10);
this(apiKey, endpoint, 0, threadCount, maxPageCount, 10, null);
}

/**
Expand All @@ -146,7 +147,7 @@ public class NvdCveClient implements PagedDataSource<DefCveItem> {
* @param maxRetryCount the maximum number of retries for 503 and 429 status code responses.
*/
NvdCveClient(String apiKey, String endpoint, int threadCount, int maxPageCount, int maxRetryCount) {
this(apiKey, endpoint, apiKey == null ? 6500 : 600, threadCount, maxPageCount, maxRetryCount);
this(apiKey, endpoint, 0, threadCount, maxPageCount, maxRetryCount, null);
}

/**
Expand All @@ -158,8 +159,10 @@ public class NvdCveClient implements PagedDataSource<DefCveItem> {
* @param threadCount the number of threads to use when calling the NVD API.
* @param maxPageCount the maximum number of pages to retrieve from the NVD API.
* @param maxRetryCount the maximum number of retries for 503 and 429 status code responses.
* @param httpClientSupplier supplier for custom HTTP clients; if {@code null} a default client will be used
*/
NvdCveClient(String apiKey, String endpoint, long delay, int threadCount, int maxPageCount, int maxRetryCount) {
NvdCveClient(String apiKey, String endpoint, long delay, int threadCount, int maxPageCount, int maxRetryCount,
HttpAsyncClientSupplier httpClientSupplier) {
this.apiKey = apiKey;
if (endpoint == null) {
this.endpoint = DEFAULT_ENDPOINT;
Expand All @@ -186,8 +189,11 @@ public class NvdCveClient implements PagedDataSource<DefCveItem> {
meter = new RateMeter(50, 32500);
}
clients = new ArrayList<>(threadCount);
if (delay == 0) {
delay = apiKey == null ? 6500 : 600;
}
for (int i = 0; i < threadCount; i++) {
clients.add(new RateLimitedClient(maxRetryCount, delay, meter));
clients.add(new RateLimitedClient(maxRetryCount, delay, meter, httpClientSupplier));
}
objectMapper = new ObjectMapper();
objectMapper.registerModule(new JavaTimeModule());
Expand Down
Loading

0 comments on commit d9b8297

Please sign in to comment.