Skip to content

Commit

Permalink
Reconnect example (#140)
Browse files Browse the repository at this point in the history
  • Loading branch information
wboayue authored Oct 22, 2024
1 parent f629b3c commit 2f47394
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 70 deletions.
30 changes: 30 additions & 0 deletions examples/stream_retry.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use ibapi::contracts::Contract;
use ibapi::market_data::realtime::{BarSize, WhatToShow};
use ibapi::{Client, Error};

fn main() {
env_logger::init();

let connection_url = "127.0.0.1:4002";
let client = Client::connect(connection_url, 100).expect("connection to TWS failed!");

let contract = Contract::stock("AAPL");
stream_bars(&client, &contract);
}

// Request real-time bars data with 5-second intervals
fn stream_bars(client: &Client, contract: &Contract) {
let subscription = client
.realtime_bars(&contract, BarSize::Sec5, WhatToShow::Trades, false)
.expect("realtime bars request failed!");

for bar in &subscription {
// Process each bar here (e.g., print or use in calculations)
println!("bar: {bar:?}");
}

if let Some(Error::ConnectionReset) = subscription.error() {
println!("Connection reset. Retrying stream...");
stream_bars(client, contract);
}
}
115 changes: 49 additions & 66 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::fmt::Debug;
use std::marker::PhantomData;
use std::sync::atomic::{AtomicBool, AtomicI32, Ordering};
use std::sync::Arc;
use std::sync::{Arc, Mutex};
use std::time::Duration;

use log::{debug, error, info, warn};
Expand Down Expand Up @@ -1141,14 +1141,15 @@ impl Debug for Client {
///
#[allow(private_bounds)]
pub struct Subscription<'a, T: Subscribable<T>> {
pub(crate) client: &'a Client,
pub(crate) request_id: Option<i32>,
pub(crate) order_id: Option<i32>,
pub(crate) message_type: Option<OutgoingMessages>,
pub(crate) phantom: PhantomData<T>,
client: &'a Client,
request_id: Option<i32>,
order_id: Option<i32>,
message_type: Option<OutgoingMessages>,
phantom: PhantomData<T>,
cancelled: AtomicBool,
subscription: InternalSubscription,
response_context: ResponseContext,
error: Arc<Mutex<Option<Error>>>,
}

// Extra metadata that might be need
Expand All @@ -1170,6 +1171,7 @@ impl<'a, T: Subscribable<T>> Subscription<'a, T> {
response_context: context,
phantom: PhantomData,
cancelled: AtomicBool::new(false),
error: Arc::new(Mutex::new(None)),
}
} else if let Some(order_id) = subscription.order_id {
Subscription {
Expand All @@ -1181,6 +1183,7 @@ impl<'a, T: Subscribable<T>> Subscription<'a, T> {
response_context: context,
phantom: PhantomData,
cancelled: AtomicBool::new(false),
error: Arc::new(Mutex::new(None)),
}
} else if let Some(message_type) = subscription.message_type {
Subscription {
Expand All @@ -1192,6 +1195,7 @@ impl<'a, T: Subscribable<T>> Subscription<'a, T> {
response_context: context,
phantom: PhantomData,
cancelled: AtomicBool::new(false),
error: Arc::new(Mutex::new(None)),
}
} else {
panic!("unsupported internal subscription: {:?}", subscription)
Expand All @@ -1200,36 +1204,40 @@ impl<'a, T: Subscribable<T>> Subscription<'a, T> {

/// Blocks until the item become available.
pub fn next(&self) -> Option<T> {
loop {
match self.subscription.next() {
Some(Ok(mut message)) => {
if T::RESPONSE_MESSAGE_IDS.contains(&message.message_type()) {
match T::decode(self.client.server_version(), &mut message) {
Ok(val) => return Some(val),
Err(err) => {
error!("error decoding execution data: {err}");
}
}
} else if message.message_type() == IncomingMessages::Error {
let error_message = message.peek_string(4);
error!("{error_message}");
return None;
} else {
info!("subscription iterator unexpected message: {message:?}");
}
}
Some(Err(Error::Cancelled)) => {
debug!("subscription cancelled");
return None;
}
Some(Err(Error::Shutdown)) => {
debug!("server disconnected");
return None;
}
_ => {
return None;
self.process_response(self.subscription.next())
}

fn process_response(&self, response: Option<Result<ResponseMessage, Error>>) -> Option<T> {
match response {
Some(Ok(message)) => self.process_message(message),
Some(Err(e)) => {
let mut error = self.error.lock().unwrap();
*error = Some(e);
None
}
None => None,
}
}

fn process_message(&self, mut message: ResponseMessage) -> Option<T> {
if T::RESPONSE_MESSAGE_IDS.contains(&message.message_type()) {
match T::decode(self.client.server_version(), &mut message) {
Ok(val) => Some(val),
Err(err) => {
let mut error = self.error.lock().unwrap();
*error = Some(err);
None
}
}
} else if message.message_type() == IncomingMessages::Error {
let error_message = message.peek_string(4);
error!("{error_message}");
let mut error = self.error.lock().unwrap();
*error = Some(Error::Simple(error_message));
None
} else {
info!("subscription iterator unexpected message: {message:?}");
None
}
}

Expand All @@ -1245,22 +1253,7 @@ impl<'a, T: Subscribable<T>> Subscription<'a, T> {
/// //}
/// ```
pub fn try_next(&self) -> Option<T> {
if let Some(Ok(mut message)) = self.subscription.try_next() {
if message.message_type() == IncomingMessages::Error {
error!("{}", message.peek_string(4));
return None;
}

match T::decode(self.client.server_version(), &mut message) {
Ok(val) => Some(val),
Err(err) => {
error!("error decoding message: {err}");
None
}
}
} else {
None
}
self.process_response(self.subscription.try_next())
}

/// To request the next bar in a non-blocking manner.
Expand All @@ -1275,22 +1268,7 @@ impl<'a, T: Subscribable<T>> Subscription<'a, T> {
/// //}
/// ```
pub fn next_timeout(&self, timeout: Duration) -> Option<T> {
if let Some(Ok(mut message)) = self.subscription.next_timeout(timeout) {
if message.message_type() == IncomingMessages::Error {
error!("{}", message.peek_string(4));
return None;
}

match T::decode(self.client.server_version(), &mut message) {
Ok(val) => Some(val),
Err(err) => {
error!("error decoding message: {err}");
None
}
}
} else {
None
}
self.process_response(self.subscription.next_timeout(timeout))
}

/// Cancel the subscription
Expand Down Expand Up @@ -1338,6 +1316,11 @@ impl<'a, T: Subscribable<T>> Subscription<'a, T> {
pub fn timeout_iter(&self, timeout: Duration) -> SubscriptionTimeoutIter<T> {
SubscriptionTimeoutIter { subscription: self, timeout }
}

pub fn error(&self) -> Option<Error> {
let error = self.error.lock().unwrap();
error.clone()
}
}

impl<'a, T: Subscribable<T>> Drop for Subscription<'a, T> {
Expand Down
2 changes: 2 additions & 0 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub enum Error {
ServerVersion(i32, i32, String),
Simple(String),
ConnectionFailed,
ConnectionReset,
Cancelled,
Shutdown,
}
Expand All @@ -35,6 +36,7 @@ impl std::fmt::Display for Error {
Error::Parse(i, value, message) => write!(f, "parse error: {i} - {value} - {message}"),
Error::ServerVersion(wanted, have, message) => write!(f, "server version {wanted} required, got {have}: {message}"),
Error::ConnectionFailed => write!(f, "ConnectionFailed"),
Error::ConnectionReset => write!(f, "ConnectionReset"),
Error::Cancelled => write!(f, "Cancelled"),
Error::Shutdown => write!(f, "Shutdown"),

Expand Down
24 changes: 20 additions & 4 deletions src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,13 @@ impl SharedChannels {
fn contains_sender(&self, message_type: IncomingMessages) -> bool {
self.senders.contains_key(&message_type)
}

// Notify all senders with a given message
fn notify_all(&self, message: &Result<ResponseMessage, Error>) {
for sender in self.senders.values() {
let _ = sender.send(message.clone());
}
}
}

// Signals are used to notify the backend when a subscriber is dropped.
Expand Down Expand Up @@ -164,6 +171,7 @@ impl TcpMessageBus {

self.requests.notify_all(&Err(Error::Shutdown));
self.orders.notify_all(&Err(Error::Shutdown));
self.shared_channels.notify_all(&Err(Error::Shutdown));

self.requests.clear();
self.orders.clear();
Expand All @@ -172,7 +180,17 @@ impl TcpMessageBus {
self.shutdown_requested.store(true, Ordering::Relaxed);
}

fn reset(&self) {}
fn reset(&self) {
debug!("reset message bus");

self.requests.notify_all(&Err(Error::ConnectionReset));
self.orders.notify_all(&Err(Error::ConnectionReset));
self.shared_channels.notify_all(&Err(Error::ConnectionReset));

self.requests.clear();
self.orders.clear();
self.executions.clear();
}

fn clean_request(&self, request_id: i32) {
self.requests.remove(&request_id);
Expand Down Expand Up @@ -217,10 +235,8 @@ impl TcpMessageBus {
return;
}

message_bus.reset();

info!("successfully reconnected to TWS/Gateway");

message_bus.reset();
continue;
}
Err(err) => {
Expand Down

0 comments on commit 2f47394

Please sign in to comment.