From 71eeb18003b3949a43557866ce68bd9b2837a9d9 Mon Sep 17 00:00:00 2001 From: Philippe Pepiot Date: Mon, 29 Feb 2016 11:38:16 +0100 Subject: [PATCH] Paramiko: Handle closed session When the connection is dropped (for instance if the target host is restarted between tests), paramiko raises "paramiko.ssh_exception.SSHException: SSH session not active". In this case we try to reconnect (once). --- testinfra/backend/paramiko.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/testinfra/backend/paramiko.py b/testinfra/backend/paramiko.py index a19fec73..27a051f3 100644 --- a/testinfra/backend/paramiko.py +++ b/testinfra/backend/paramiko.py @@ -1,5 +1,5 @@ # -*- coding: utf8 -*- -# Copyright © 2015 Philippe Pepiot +# Copyright © 2015-2016 Philippe Pepiot # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,6 +25,7 @@ HAS_PARAMIKO = False else: HAS_PARAMIKO = True + import paramiko.ssh_exception class IgnorePolicy(paramiko.MissingHostKeyPolicy): """Policy for ignoring missing host key.""" @@ -79,14 +80,27 @@ def client(self): self._client = client return self._client - def run(self, command, *args, **kwargs): - command = self.get_command(command, *args) + def _exec_command(self, command): chan = self.client.get_transport().open_session() - command = self.encode(command) chan.exec_command(command) rc = chan.recv_exit_status() stdout = b''.join(chan.makefile('rb')) stderr = b''.join(chan.makefile_stderr('rb')) + return rc, stdout, stderr + + def run(self, command, *args, **kwargs): + command = self.get_command(command, *args) + command = self.encode(command) + try: + rc, stdout, stderr = self._exec_command(command) + except paramiko.ssh_exception.SSHException: + if not self.client.get_transport().is_active(): + # try to reinit connection (once) + self._client = None + rc, stdout, stderr = self._exec_command(command) + else: + raise + result = base.CommandResult(self, rc, stdout, stderr, command) logger.info("RUN %s", result) return result