Skip to content

Commit

Permalink
Merge pull request #725 from ikolomi/add_clientname_support
Browse files Browse the repository at this point in the history
Add python/core support for configuring client name during connection…
  • Loading branch information
ikolomi authored Dec 28, 2023
2 parents 60ac1c7 + f83258f commit e5bd265
Show file tree
Hide file tree
Showing 10 changed files with 102 additions and 21 deletions.
2 changes: 2 additions & 0 deletions examples/python/client_example.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ async def test_standalone_client(host: str = "localhost", port: int = 6379):
# Check `RedisClientConfiguration/ClusterClientConfiguration` for additional options.
config = BaseClientConfiguration(
addresses=addresses,
client_name = 'test_standalone_client'
# use_tls=True
)
client = await RedisClient.create(config)
Expand All @@ -57,6 +58,7 @@ async def test_cluster_client(host: str = "localhost", port: int = 6379):
# Check `RedisClientConfiguration/ClusterClientConfiguration` for additional options.
config = BaseClientConfiguration(
addresses=addresses,
client_name = 'test_cluster_client'
# use_tls=True
)
client = await RedisClusterClient.create(config)
Expand Down
32 changes: 18 additions & 14 deletions glide-core/src/client/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::connection_request::{
AuthenticationInfo, ConnectionRequest, NodeAddress, ProtocolVersion, ReadFrom, TlsMode,
ConnectionRequest, NodeAddress, ProtocolVersion, ReadFrom, TlsMode,
};
use futures::FutureExt;
use logger_core::log_info;
Expand Down Expand Up @@ -44,22 +44,21 @@ pub fn convert_to_redis_protocol(protocol: ProtocolVersion) -> redis::ProtocolVe
}

pub(super) fn get_redis_connection_info(
authentication_info: Option<Box<AuthenticationInfo>>,
database_id: u32,
protocol: ProtocolVersion,
connection_request: &ConnectionRequest,
) -> redis::RedisConnectionInfo {
let protocol = convert_to_redis_protocol(protocol);
match authentication_info {
let protocol = convert_to_redis_protocol(connection_request.protocol.enum_value_or_default());
match connection_request.authentication_info.0.as_ref() {
Some(info) => redis::RedisConnectionInfo {
db: database_id as i64,
db: connection_request.database_id as i64,
username: chars_to_string_option(&info.username),
password: chars_to_string_option(&info.password),
protocol,
client_name: None,
client_name: chars_to_string_option(&connection_request.client_name),
},
None => redis::RedisConnectionInfo {
db: database_id as i64,
db: connection_request.database_id as i64,
protocol,
client_name: chars_to_string_option(&connection_request.client_name),
..Default::default()
},
}
Expand Down Expand Up @@ -271,9 +270,7 @@ async fn create_cluster_client(
) -> RedisResult<redis::cluster_async::ClusterConnection> {
// TODO - implement timeout for each connection attempt
let tls_mode = request.tls_mode.enum_value_or_default();
let protocol = request.protocol.enum_value_or_default();
let redis_connection_info =
get_redis_connection_info(request.authentication_info.0, 0, protocol);
let redis_connection_info = get_redis_connection_info(&request);
let initial_nodes: Vec<_> = request
.addresses
.into_iter()
Expand All @@ -286,7 +283,12 @@ async fn create_cluster_client(
if read_from_replicas {
builder = builder.read_from_replicas();
}
builder = builder.use_protocol(convert_to_redis_protocol(protocol));
builder = builder.use_protocol(convert_to_redis_protocol(
request.protocol.enum_value_or_default(),
));
if let Some(client_name) = redis_connection_info.client_name {
builder = builder.client_name(client_name);
}
if tls_mode != TlsMode::NoTls {
let tls = if tls_mode == TlsMode::SecureTls {
redis::cluster::TlsMode::Secure
Expand Down Expand Up @@ -353,9 +355,11 @@ fn sanitized_request_string(request: &ConnectionRequest) -> String {
}
None => String::new(),
};
let protocol = request.protocol.enum_value_or_default();
let client_name = chars_to_string_option(&request.client_name);

format!(
"\naddresses: {addresses}\nTLS mode: {tls_mode:?}\ncluster mode: {cluster_mode}{request_timeout}\nRead from replica strategy: {rfr_strategy:?}{database_id}{connection_retry_strategy}",
"\nAddresses: {addresses}\nTLS mode: {tls_mode:?}\nCluster mode: {cluster_mode}\nRequest timeout: {request_timeout}\nRead from replica strategy: {rfr_strategy:?}\nConnection retry strategy: {connection_retry_strategy}\nDatabase id: {database_id}\nProtocol: {protocol:?}\nClient name: {client_name:?}",
)
}

Expand Down
6 changes: 1 addition & 5 deletions glide-core/src/client/standalone_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,7 @@ impl StandaloneClient {
return Err(StandaloneClientConnectionError::NoAddressesProvided);
}
let retry_strategy = RetryStrategy::new(&connection_request.connection_retry_strategy.0);
let redis_connection_info = get_redis_connection_info(
connection_request.authentication_info.0,
connection_request.database_id,
connection_request.protocol.enum_value_or_default(),
);
let redis_connection_info = get_redis_connection_info(&connection_request);

let tls_mode = connection_request.tls_mode.enum_value_or_default();
let node_count = connection_request.addresses.len();
Expand Down
1 change: 1 addition & 0 deletions glide-core/src/protobuf/connection_request.proto
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ message ConnectionRequest {
AuthenticationInfo authentication_info = 7;
uint32 database_id = 8;
ProtocolVersion protocol = 9;
string client_name = 10;
}

message ConnectionRetryStrategy {
Expand Down
47 changes: 47 additions & 0 deletions glide-core/tests/test_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,53 @@ mod shared_client_tests {
});
}

#[rstest]
#[timeout(SHORT_CLUSTER_TEST_TIMEOUT)]
fn test_client_name_after_reconnection(#[values(false, true)] use_cluster: bool) {
const CLIENT_NAME: &str = "TEST_CLIENT_NAME";
let mut client_info_cmd = redis::Cmd::new();
client_info_cmd.arg("CLIENT").arg("INFO");
block_on_all(async move {
let test_basics = setup_test_basics(
use_cluster,
TestConfiguration {
shared_server: true,
client_name: Some(CLIENT_NAME.to_string()),
..Default::default()
},
)
.await;
let mut client = test_basics.client;
let client_info: String = redis::from_redis_value(
&client.send_command(&client_info_cmd, None).await.unwrap(),
)
.unwrap();
assert!(client_info.contains(&format!("name={CLIENT_NAME}")));

kill_connection(&mut client).await;

let error = client.send_command(&client_info_cmd, None).await;
// In Standalone mode the error is passed back to the client,
// while in Cluster mode the request is retried with reconnect
if !use_cluster {
assert!(error.is_err(), "{error:?}",);
let error = error.unwrap_err();
assert!(
error.is_connection_dropped() || error.is_timeout(),
"{error:?}",
);
}
let client_info: String = repeat_try_create(|| async {
let mut client = client.clone();
redis::from_redis_value(&client.send_command(&client_info_cmd, None).await.unwrap())
.ok()
})
.await;

assert!(client_info.contains(&format!("name={CLIENT_NAME}")));
});
}

#[rstest]
#[timeout(SHORT_CLUSTER_TEST_TIMEOUT)]
fn test_request_transaction_and_convert_all_values(#[values(false, true)] use_cluster: bool) {
Expand Down
10 changes: 8 additions & 2 deletions glide-core/tests/utilities/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ use redis::{
};
use socket2::{Domain, Socket, Type};
use std::{
env, fs, io, net::SocketAddr, net::TcpListener, path::PathBuf, process, sync::Mutex,
time::Duration,
env, fs, io, net::SocketAddr, net::TcpListener, ops::Deref, path::PathBuf, process,
sync::Mutex, time::Duration,
};
use tempfile::TempDir;

Expand Down Expand Up @@ -643,6 +643,11 @@ pub fn create_connection_request(
configuration.connection_info.clone().unwrap_or_default(),
&mut connection_request,
);

if let Some(client_name) = &configuration.client_name {
connection_request.client_name = client_name.deref().into();
}

connection_request
}

Expand All @@ -656,6 +661,7 @@ pub struct TestConfiguration {
pub shared_server: bool,
pub read_from: Option<connection_request::ReadFrom>,
pub database_id: u32,
pub client_name: Option<String>,
pub protocol: ProtocolVersion,
}

Expand Down
11 changes: 11 additions & 0 deletions python/python/glide/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def __init__(
credentials: Optional[RedisCredentials] = None,
read_from: ReadFrom = ReadFrom.PRIMARY,
request_timeout: Optional[int] = None,
client_name: Optional[str] = None,
):
"""
Represents the configuration settings for a Redis client.
Expand All @@ -106,12 +107,14 @@ def __init__(
request_timeout (Optional[int]): The duration in milliseconds that the client should wait for a request to complete.
This duration encompasses sending the request, awaiting for a response from the server, and any required reconnections or retries.
If the specified timeout is exceeded for a pending request, it will result in a timeout error. If not set, a default value will be used.
client_name (Optional[str]): Client name to be used for the client. Will be used with CLIENT SETNAME command during connection establishment.
"""
self.addresses = addresses or [NodeAddress()]
self.use_tls = use_tls
self.credentials = credentials
self.read_from = read_from
self.request_timeout = request_timeout
self.client_name = client_name

def _create_a_protobuf_conn_request(
self, cluster_mode: bool = False
Expand Down Expand Up @@ -139,6 +142,8 @@ def _create_a_protobuf_conn_request(
if self.credentials.username:
request.authentication_info.username = self.credentials.username
request.authentication_info.password = self.credentials.password
if self.client_name:
request.client_name = self.client_name
request.protocol = SentProtocolVersion.RESP2

return request
Expand Down Expand Up @@ -169,6 +174,7 @@ class RedisClientConfiguration(BaseClientConfiguration):
connection failures.
If not set, a default backoff strategy will be used.
database_id (Optional[Int]): index of the logical database to connect to.
client_name (Optional[str]): Client name to be used for the client. Will be used with CLIENT SETNAME command during connection establishment.
"""

def __init__(
Expand All @@ -180,13 +186,15 @@ def __init__(
request_timeout: Optional[int] = None,
reconnect_strategy: Optional[BackoffStrategy] = None,
database_id: Optional[int] = None,
client_name: Optional[str] = None,
):
super().__init__(
addresses=addresses,
use_tls=use_tls,
credentials=credentials,
read_from=read_from,
request_timeout=request_timeout,
client_name=client_name,
)
self.reconnect_strategy = reconnect_strategy
self.database_id = database_id
Expand Down Expand Up @@ -229,6 +237,7 @@ class ClusterClientConfiguration(BaseClientConfiguration):
request_timeout (Optional[int]): The duration in milliseconds that the client should wait for a request to complete.
This duration encompasses sending the request, awaiting for a response from the server, and any required reconnections or retries.
If the specified timeout is exceeded for a pending request, it will result in a timeout error. If not set, a default value will be used.
client_name (Optional[str]): Client name to be used for the client. Will be used with CLIENT SETNAME command during connection establishment.
Notes:
Currently, the reconnection strategy in cluster mode is not configurable, and exponential backoff
Expand All @@ -242,11 +251,13 @@ def __init__(
credentials: Optional[RedisCredentials] = None,
read_from: ReadFrom = ReadFrom.PRIMARY,
request_timeout: Optional[int] = None,
client_name: Optional[str] = None,
):
super().__init__(
addresses=addresses,
use_tls=use_tls,
credentials=credentials,
read_from=read_from,
request_timeout=request_timeout,
client_name=client_name,
)
3 changes: 3 additions & 0 deletions python/python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ async def create_client(
credentials: Optional[RedisCredentials] = None,
database_id: int = 0,
addresses: Optional[List[NodeAddress]] = None,
client_name: Optional[str] = None,
) -> Union[RedisClient, RedisClusterClient]:
# Create async socket client
use_tls = request.config.getoption("--tls")
Expand All @@ -118,6 +119,7 @@ async def create_client(
addresses=seed_nodes if addresses is None else addresses,
use_tls=use_tls,
credentials=credentials,
client_name=client_name,
)
return await RedisClusterClient.create(cluster_config)
else:
Expand All @@ -129,5 +131,6 @@ async def create_client(
use_tls=use_tls,
credentials=credentials,
database_id=database_id,
client_name=client_name,
)
return await RedisClient.create(config)
8 changes: 8 additions & 0 deletions python/python/tests/test_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,14 @@ async def test_select_standalone_database_id(self, request):
client_info = await redis_client.custom_command(["CLIENT", "INFO"])
assert "db=4" in client_info

@pytest.mark.parametrize("cluster_mode", [True, False])
async def test_client_name(self, request, cluster_mode):
redis_client = await create_client(
request, cluster_mode=cluster_mode, client_name="TEST_CLIENT_NAME"
)
client_info = await redis_client.custom_command(["CLIENT", "INFO"])
assert "name=TEST_CLIENT_NAME" in client_info


@pytest.mark.asyncio
class TestCommands:
Expand Down
3 changes: 3 additions & 0 deletions python/python/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,20 @@ def test_default_client_config():
assert config.addresses[0].port == 6379
assert config.read_from.value == ProtobufReadFrom.Primary
assert config.use_tls is False
assert config.client_name is None


def test_convert_to_protobuf():
config = BaseClientConfiguration(
[NodeAddress("127.0.0.1")],
use_tls=True,
read_from=ReadFrom.PREFER_REPLICA,
client_name="TEST_CLIENT_NAME",
)
request = config._create_a_protobuf_conn_request()
assert isinstance(request, ConnectionRequest)
assert request.addresses[0].host == "127.0.0.1"
assert request.addresses[0].port == 6379
assert request.tls_mode is TlsMode.SecureTls
assert request.read_from == ProtobufReadFrom.PreferReplica
assert request.client_name == "TEST_CLIENT_NAME"

0 comments on commit e5bd265

Please sign in to comment.