Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_send_data_with_gap_1_retransmit(self):
    """
    Queue eight full-size DATA chunks and check the initial burst.

    Only the first three chunks (TSNs 0-2) are handed to the wire because
    the congestion window allows three; TSNs 3-7 remain queued.
    """
    transmitted_tsns = []

    async def record_chunk(chunk):
        # Capture the TSN instead of actually transmitting the chunk.
        transmitted_tsns.append(chunk.tsn)

    client = RTCSctpTransport(self.client_transport)
    # 0xFFFFFFFF == 2**32 - 1, the serial-number predecessor of TSN 0
    # (presumably meaning "nothing acknowledged yet" — confirm).
    client._last_sacked_tsn = 0xFFFFFFFF
    client._local_tsn = 0
    client._ssthresh = 131072
    client._send_chunk = record_chunk

    # queue 8 chunks, but cwnd only allows 3
    payload = b"M" * USERDATA_MAX_LENGTH * 8
    with self.assertTimerRestarted(client):
        run(client._send(123, 456, payload))

    self.assertEqual(client._cwnd, 3600)
    self.assertEqual(client._fast_recovery_exit, None)
    self.assertEqual(client._flight_size, 3600)
    self.assertEqual(transmitted_tsns, [0, 1, 2])
    self.assertEqual(outstanding_tsns(client), [0, 1, 2])
    self.assertEqual(queued_tsns(client), [3, 4, 5, 6, 7])
def test_t3_expired(self):
    """
    Check the transport state after the T3 retransmission timer fires.

    One chunk is sent and becomes outstanding. After `_t3_expired()` the
    timer handle is cleared, the chunk is still outstanding, nothing is
    queued, and every chunk remaining in the outbound queue is flagged
    for retransmission.
    """

    async def mock_send_chunk(chunk):
        pass

    async def mock_transmit():
        pass

    client = RTCSctpTransport(self.client_transport)
    client._local_tsn = 0
    client._send_chunk = mock_send_chunk

    # 1 chunk
    run(client._send(123, 456, b"M" * USERDATA_MAX_LENGTH))
    self.assertIsNotNone(client._t3_handle)
    self.assertEqual(outstanding_tsns(client), [0])
    self.assertEqual(queued_tsns(client), [])

    # t3 expires; _transmit is stubbed out (presumably so that expiry
    # handling cannot put anything on the wire — confirm)
    client._transmit = mock_transmit
    client._t3_expired()
    self.assertIsNone(client._t3_handle)
    self.assertEqual(outstanding_tsns(client), [0])
    self.assertEqual(queued_tsns(client), [])
    # every chunk still held after expiry must be marked for retransmission
    for chunk in client._outbound_queue:
        self.assertTrue(chunk._retransmit)
def test_connect_then_client_creates_negotiated_data_channel_without_id(self):
    """
    Client and server establish an association with default stream limits.

    NOTE(review): despite the test name, this body never creates a
    negotiated data channel; that step appears to be missing — confirm.
    """
    client = RTCSctpTransport(self.client_transport)
    self.assertFalse(client.is_server)
    server = RTCSctpTransport(self.server_transport)
    self.assertTrue(server.is_server)

    # connect
    run(server.start(client.getCapabilities(), client.port))
    run(client.start(server.getCapabilities(), server.port))

    # check outcome
    run(wait_for_outcome(client, server))
    self.assertEqual(client._association_state, RTCSctpTransport.State.ESTABLISHED)
    self.assertEqual(client._inbound_streams_count, 65535)
    self.assertEqual(client._outbound_streams_count, 65535)
    self.assertEqual(client._remote_extensions, [192, 130])
    self.assertEqual(server._association_state, RTCSctpTransport.State.ESTABLISHED)
    self.assertEqual(server._inbound_streams_count, 65535)

    # shutdown both transports so they do not outlive the test
    run(client.stop())
    run(server.stop())
def test_receive_shutdown(self):
    """
    Check an established association's response to the peer's shutdown.

    On SHUTDOWN it moves to SHUTDOWN_ACK_SENT, and on SHUTDOWN COMPLETE
    it moves to CLOSED.
    """

    async def discard_chunk(chunk):
        # swallow every outgoing chunk
        pass

    transport = RTCSctpTransport(self.client_transport)
    transport._last_received_tsn = 0
    transport._send_chunk = discard_chunk
    transport._set_state(RTCSctpTransport.State.ESTABLISHED)

    # peer sends SHUTDOWN with a cumulative TSN one before our last SACKed TSN
    shutdown = ShutdownChunk()
    shutdown.cumulative_tsn = tsn_minus_one(transport._last_sacked_tsn)
    run(transport._receive_chunk(shutdown))
    self.assertEqual(
        transport._association_state, RTCSctpTransport.State.SHUTDOWN_ACK_SENT
    )

    # peer sends SHUTDOWN COMPLETE
    run(transport._receive_chunk(ShutdownCompleteChunk()))
    self.assertEqual(transport._association_state, RTCSctpTransport.State.CLOSED)
def test_construct(self):
    """
    A new transport exposes the transport it wraps and reports port 5000.
    """
    underlying = self.client_transport
    sctp_transport = RTCSctpTransport(underlying)
    self.assertEqual(sctp_transport.transport, underlying)
    self.assertEqual(sctp_transport.port, 5000)
def test_bad_verification_tag(self):
    """
    An INIT with a wrong verification tag must not open an association.

    The captured packet carries verification tag 12345 instead of 0, so
    the server is expected to remain CLOSED.
    """
    # verification tag is 12345 instead of 0
    data = load("sctp_init_bad_verification.bin")
    server = RTCSctpTransport(self.server_transport)
    run(server.start(RTCSctpCapabilities(maxMessageSize=65536), 5000))

    # Keep a reference to the task: the event loop only keeps a weak
    # reference, so an unreferenced task may be garbage-collected before
    # it finishes.
    send_task = asyncio.ensure_future(self.client_transport._send_data(data))

    # check outcome; first make sure the bad packet was actually sent,
    # otherwise CLOSED would be asserted vacuously
    run(asyncio.sleep(0.1))
    self.assertTrue(send_task.done())
    self.assertEqual(server._association_state, RTCSctpTransport.State.CLOSED)

    # shutdown
    run(server.stop())
def test_connect_broken_transport(self):
    """
    Transport with 100% loss never connects.

    Both directions drop every packet. The client gives up and ends up
    "closed", while the server is still "connecting" when the test checks.
    """
    loss_pattern = [True]
    self.client_transport.transport._connection.loss_pattern = loss_pattern
    self.server_transport.transport._connection.loss_pattern = loss_pattern

    client = RTCSctpTransport(self.client_transport)
    client._rto = 0.1  # short RTO so the failed handshake concludes quickly
    self.assertFalse(client.is_server)
    server = RTCSctpTransport(self.server_transport)
    server._rto = 0.1
    self.assertTrue(server.is_server)

    # connect
    run(server.start(client.getCapabilities(), client.port))
    run(client.start(server.getCapabilities(), server.port))

    # check outcome
    run(wait_for_outcome(client, server))
    self.assertEqual(client._association_state, RTCSctpTransport.State.CLOSED)
    self.assertEqual(client.state, "closed")
    self.assertEqual(server._association_state, RTCSctpTransport.State.CLOSED)
    self.assertEqual(server.state, "connecting")

    # shutdown: the server is still "connecting", so it must be stopped too
    run(client.stop())
    run(server.stop())
def test_connect_server_limits_streams(self):
    """
    Server-side stream limits are reflected in the negotiated association.

    The server caps inbound streams at 2048 and offers 256 outbound
    streams, so the client ends up with 256 inbound and 2048 outbound
    streams.
    """
    client = RTCSctpTransport(self.client_transport)
    self.assertFalse(client.is_server)
    server = RTCSctpTransport(self.server_transport)
    server._inbound_streams_max = 2048
    server._outbound_streams_count = 256
    self.assertTrue(server.is_server)

    # connect
    run(server.start(client.getCapabilities(), client.port))
    run(client.start(server.getCapabilities(), server.port))

    # check outcome
    run(wait_for_outcome(client, server))
    self.assertEqual(client._association_state, RTCSctpTransport.State.ESTABLISHED)
    self.assertEqual(client._inbound_streams_count, 256)
    self.assertEqual(client._outbound_streams_count, 2048)
    self.assertEqual(client._remote_extensions, [192, 130])
    self.assertEqual(server._association_state, RTCSctpTransport.State.ESTABLISHED)
    self.assertEqual(server._inbound_streams_count, 2048)

    # shutdown both transports so they do not outlive the test
    run(client.stop())
    run(server.stop())
def test_receive_heartbeat(self):
    """
    A received HEARTBEAT is answered with a HEARTBEAT ACK echoing its params.
    """
    ack = None

    async def mock_send_chunk(chunk):
        # remember the chunk the transport sends back
        nonlocal ack
        ack = chunk

    client = RTCSctpTransport(self.client_transport)
    client._last_received_tsn = 0
    client._remote_port = 5000
    client._send_chunk = mock_send_chunk

    # receive heartbeat
    chunk = HeartbeatChunk()
    chunk.params.append((1, b"\x01\x02\x03\x04"))
    # NOTE(review): HEARTBEAT chunks carry no TSN in SCTP; this assignment
    # looks irrelevant to the receive path — confirm before removing.
    chunk.tsn = 1
    run(client._receive_chunk(chunk))

    # check response (assertIsInstance reports the actual type on failure)
    self.assertIsInstance(ack, HeartbeatAckChunk)
    self.assertEqual(ack.params, [(1, b"\x01\x02\x03\x04")])
def test_receive_forward_tsn(self):
received = []
async def fake_receive(*args):
received.append(args)
client = RTCSctpTransport(self.client_transport)
client._last_received_tsn = 101
client._receive = fake_receive
factory = ChunkFactory(tsn=102)
chunks = (
factory.create([b"foo"])
+ factory.create([b"baz"])
+ factory.create([b"qux"])
+ factory.create([b"quux"])
+ factory.create([b"corge"])
+ factory.create([b"grault"])
)
# receive chunks with gaps
for i in [0, 2, 3, 5]:
run(client._receive_chunk(chunks[i]))