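# virtual-deployment/virtualbox/pybox/utils/tests/test_serial.py
# (source listing: 182 lines, 4.9 KiB, Python)
#
# Unit tests for the pybox ``serial`` utility module: connect, disconnect,
# get_output, expect_bytes and send_bytes.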

import unittest
from unittest.mock import MagicMock, patch, ANY
import serial
import socket
class ConnectTestCase(unittest.TestCase):
    """
    Tests for serial.connect.
    """

    @patch("serial.LOG.info")
    @patch("socket.socket")
    def test_connect_windows(self, mock_socket, mock_log_info):
        """
        connect() should create an AF_INET/SOCK_STREAM/IPPROTO_TCP socket,
        connect it to ('localhost', port) and return that socket object.

        NOTE(review): despite its name, this test does not force a Windows
        platform; it relies on serial.connect using a TCP socket in the
        environment the tests run in -- confirm against serial.py.
        """
        # Setup: keep the socket *instance* distinct from the patched socket
        # *class*, so the assertions can tell "returned the created socket"
        # apart from "returned the class", and instance calls are not
        # recorded on the constructor mock.
        mock_sock = mock_socket.return_value
        hostname = 'hostname'
        port = 10000

        # Run
        result = serial.connect(hostname, port)

        # Assert
        mock_socket.assert_called_once_with(
            socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP
        )
        mock_sock.connect.assert_called_once_with(('localhost', port))
        self.assertIs(result, mock_sock)

    @patch("serial.LOG.info")
    @patch("socket.socket")
    def test_connect_fail(self, mock_socket, mock_log_info):
        """
        When the socket's connect() raises, serial.connect should log
        "Failed sock connection" and return None.
        """
        # Setup: distinct instance mock whose connect() raises.
        mock_sock = mock_socket.return_value
        hostname = 'hostname'
        port = 10000
        mock_sock.connect.side_effect = Exception

        # Run
        result = serial.connect(hostname, port)

        # Assert
        mock_socket.assert_called_once_with(
            socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP
        )
        mock_sock.connect.assert_called_once_with(('localhost', port))
        mock_log_info.assert_called_with("Failed sock connection")
        self.assertIsNone(result)


class DisconnectTestCase(unittest.TestCase):
    """
    Tests for serial.disconnect.
    """

    @patch("serial.LOG.info")
    def test_disconnect(self, mock_log_info):
        """
        disconnect() should shut the socket down for both reading and
        writing, close it exactly once, and emit at least one info log.
        """
        fake_sock = MagicMock()

        serial.disconnect(fake_sock)

        fake_sock.shutdown.assert_called_once_with(socket.SHUT_RDWR)
        fake_sock.close.assert_called_once()
        # Only checks that *something* was logged; the message is not pinned.
        mock_log_info.assert_any_call(ANY)


# TODO: This test exists only for coverage; get_output() needs heavy refactoring.
class GetOutputTestCase(unittest.TestCase):
    """
    Tests for serial.get_output.
    """

    @patch("serial.LOG.info")
    @patch("serial.time")
    def test_get_output(self, mock_time, mock_log_info):
        """
        Drive get_output() with a mocked stream and a mocked clock.

        The call is expected to raise with these mocks.
        NOTE(review): which exception is raised is not pinned down here --
        presumably the four-value mocked clock running out; confirm.
        Even so, the command must have been sent and the received lines
        logged before that point.
        """
        fake_stream = MagicMock()
        fake_stream.poll.return_value = None
        fake_stream.gettimeout.return_value = 1
        # Successive recv() results: the echoed command, one output line,
        # then a shell prompt.
        fake_stream.recv.side_effect = ['cmd\n', 'test\n', ':~$ ']
        mock_time.time.side_effect = [0, 1, 2, 3]

        command = "cmd"
        prompt_patterns = [
            ':~$ ',
            ':~# ',
            ':/home/wrsroot# ',
            '(keystone_.*)]$ ',
            '(keystone_.*)]# ',
        ]
        wait_timeout = 2
        log_output = True
        split_lines = True
        do_flush = True

        with self.assertRaises(Exception):
            serial.get_output(
                fake_stream, command, prompt_patterns,
                wait_timeout, log_output, split_lines, do_flush,
            )

        expected_payload = (command + "\n").encode('utf-8')
        fake_stream.sendall.assert_called_once_with(expected_payload)
        for logged_line in ('cmd', 'test'):
            mock_log_info.assert_any_call(logged_line)


class ExpectBytesTestCase(unittest.TestCase):
    """
    Tests for serial.expect_bytes.
    """

    @patch("serial.LOG.debug")
    @patch("serial.LOG.info")
    @patch("serial.stdout.write")
    def test_expect_bytes(self, mock_stdout_write, mock_log_info, mock_log_debug):
        """
        When the stream's expect_bytes() returns without raising,
        serial.expect_bytes should return 0 and log both the expectation
        (with the timeout converted to minutes) and the match.
        """
        fake_stream = MagicMock()
        fake_stream.expect_bytes.return_value = None
        fake_stream.poll.return_value = None
        expected_text = "Hello, world!"
        wait_seconds = 180
        allow_fail = False
        do_flush = True

        status = serial.expect_bytes(
            fake_stream, expected_text, wait_seconds, allow_fail, do_flush
        )

        self.assertEqual(status, 0)
        # The text is handed to the stream as UTF-8 bytes.
        fake_stream.expect_bytes.assert_called_once_with(
            expected_text.encode('utf-8'), timeout=wait_seconds
        )
        mock_stdout_write.assert_any_call('\n')
        mock_log_info.assert_any_call(
            "Expecting text within %s minutes: %s\n",
            wait_seconds / 60,
            expected_text,
        )
        mock_log_debug.assert_any_call("Found expected text: %s", expected_text)


class SendBytesTestCase(unittest.TestCase):
    """
    Tests for serial.send_bytes.
    """

    @patch("serial.LOG.info")
    @patch("serial.expect_bytes")
    def test_send_bytes(self, mock_expect_bytes, mock_log_info):
        """
        send_bytes() should write the text plus a trailing newline to the
        stream, call expect_bytes, and return 0 when expect_bytes reports 0.
        """
        mock_expect_bytes.return_value = 0

        fake_stream = MagicMock()
        fake_stream.poll.return_value = None
        payload = "Hello, world!"
        allow_fail = False
        wait_for_prompt = True
        custom_prompt = None
        wait_seconds = 180
        do_send = True
        do_flush = True

        status = serial.send_bytes(
            fake_stream, payload, allow_fail, wait_for_prompt,
            custom_prompt, wait_seconds, do_send, do_flush,
        )

        self.assertEqual(status, 0)
        mock_expect_bytes.assert_called()
        fake_stream.sendall.assert_called_once_with(
            (payload + "\n").encode('utf-8')
        )


if __name__ == '__main__':
    # Running this file directly executes every TestCase defined above.
    unittest.main()