Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.
async with await stream_server.accept() as stream:
assert stream.tls_version is None
while True:
command = await stream.receive_exactly(5)
if command == b'START':
await stream.start_tls()
assert stream.tls_version.startswith('TLSv')
elif command == b'CLOSE':
break
async with await create_tcp_server(ssl_context=server_context,
autostart_tls=False) as stream_server:
async with create_task_group() as tg:
await tg.spawn(server)
async with await connect_tcp('localhost', stream_server.port,
ssl_context=client_context) as client:
assert client.tls_version is None
await client.send(b'START') # arbitrary string
await client.start_tls()
assert client.tls_version.startswith('TLSv')
await client.send(b'CLOSE') # arbitrary string
async def test_happy_eyeballs(self, interface, expected_addr, fake_localhost_dns):
    """Connect to 'localhost' and check which local address ends up being used.

    The server listens on ``interface`` and replies to every accepted
    connection with the peer's IP address followed by ``b'\\n'``.  The client
    connects to ``'localhost'`` and asserts two things:

    * the echoed line equals ``expected_addr`` (bytes);
    * its own local address equals ``expected_addr.decode()``.

    ``fake_localhost_dns`` is not referenced in the body; it is requested only
    for its fixture side effect (presumably patching how 'localhost' resolves;
    confirm against the fixture).
    """
    async def handle_client(stream):
        # Close each accepted stream once the reply is sent, instead of
        # leaving it open until garbage collection.
        async with stream:
            addr = stream.peer_address[0]
            await stream.send(addr.encode() + b'\n')

    async def server():
        async for stream in stream_server.accept_connections():
            await tg.spawn(handle_client, stream)

    async with await create_tcp_server(interface=interface) as stream_server:
        async with create_task_group() as tg:
            await tg.spawn(server)
            async with await connect_tcp('localhost', stream_server.port) as client:
                assert await client.receive_until(b'\n', 100) == expected_addr
                assert client.address[0] == expected_addr.decode()

            # Close the listener explicitly while still inside the task group,
            # presumably so the accept loop in server() stops and the group can
            # finish.
            await stream_server.aclose()
async def test_ragged_eofs(self, server_context, client_context, server_compatible,
                           client_compatible, exc_class):
    """Check how the server reports the client going away after a reply.

    The exchange runs as follows:

    1. The client (TLS started immediately) sends ``b'bl'``.
    2. It reads the server's ``b'OK\\n'`` reply.
    3. It leaves its ``async with`` block.

    The server's next ``receive_exactly(2)`` must then raise ``exc_class``.
    ``server_compatible`` and ``client_compatible`` are forwarded as
    ``tls_standard_compatible`` to the server and client respectively.
    """
    received_chunks = []

    async def server():
        async with await stream_server.accept() as conn:
            first_chunk = await conn.receive_exactly(2)
            received_chunks.append(first_chunk)
            await conn.send(b'OK\n')
            # By now the client has finished with its stream, so a further
            # read cannot be satisfied.
            with pytest.raises(exc_class):
                await conn.receive_exactly(2)

    async with await create_tcp_server(
            ssl_context=server_context,
            tls_standard_compatible=server_compatible) as stream_server:
        async with create_task_group() as task_group:
            await task_group.spawn(server)
            async with await connect_tcp(
                    'localhost', stream_server.port, ssl_context=client_context,
                    autostart_tls=True,
                    tls_standard_compatible=client_compatible) as conn:
                await conn.send(b'bl')
                reply = await conn.receive_exactly(3)
                assert reply == b'OK\n'

    assert received_chunks == [b'bl']
async def test_receive_from_cache(self, localhost):
    """Check that a read following ``receive_until()`` returns the next byte.

    The client sends ``b'abc'`` in a single write.  The server:

    1. reads up to the ``b'a'`` delimiter;
    2. receives one more byte;
    3. echoes that byte back followed by ``b'\\n'``.

    The client must get ``b'b'``.  Per the test name, that byte is presumably
    expected to come from data ``receive_until()`` had already buffered.
    """
    async def server():
        async with await stream_server.accept() as conn:
            await conn.receive_until(b'a', 10)
            echoed = await conn.receive(1)
            await conn.send(echoed + b'\n')

    async with create_task_group() as task_group:
        async with await create_tcp_server(interface=localhost) as stream_server:
            await task_group.spawn(server)
            async with await connect_tcp(localhost, stream_server.port) as conn:
                await conn.send(b'abc')
                reply = await conn.receive_until(b'\n', 3)

    assert reply == b'b'
async def test_alpn_negotiation(self, server_context, client_context):
    """Check that both ends settle on the one ALPN protocol they share.

    The client offers ``dummy1``/``dummy2`` and the server ``dummy2``/``dummy3``.
    Both the accepted server-side stream and the client stream must report
    ``'dummy2'`` as their ``alpn_protocol``.
    """
    server_context.set_alpn_protocols(['dummy2', 'dummy3'])
    client_context.set_alpn_protocols(['dummy1', 'dummy2'])

    async def server():
        async with await stream_server.accept() as conn:
            assert conn.alpn_protocol == 'dummy2'

    async with await create_tcp_server(ssl_context=server_context) as stream_server:
        async with create_task_group() as task_group:
            await task_group.spawn(server)
            async with await connect_tcp(
                    'localhost', stream_server.port, ssl_context=client_context,
                    autostart_tls=True) as conn:
                assert conn.alpn_protocol == 'dummy2'
host, port, resource, use_ssl = _url_to_host(url, use_ssl)
if use_ssl is True:
ssl_context = ssl.create_default_context()
elif use_ssl is False:
ssl_context = None
elif isinstance(use_ssl, ssl.SSLContext):
ssl_context = use_ssl
else:
raise TypeError('use_ssl argument must be bool or ssl.SSLContext')
logger.info('Connecting to %s...', url)
tls = True if ssl_context else False
stream = await anyio.connect_tcp(
host, int(port), ssl_context=ssl_context, autostart_tls=tls, tls_standard_compatible=False)
if port in (80, 443):
host_header = host
else:
host_header = '{}:{}'.format(host, port)
wsproto = WSConnection(ConnectionType.CLIENT)
connection = WebSocketConnection(
stream, wsproto, host=host_header, path=resource, subprotocols=subprotocols,
headers=headers, message_queue_size=message_queue_size, max_message_size=max_message_size
)
await task_group.spawn(connection._reader_task)
await connection._open_handshake.wait()
return connection
async def __aenter__(self):
    """Connect to ``self._host``:``self._port`` and wrap the stream for gRPC.

    A plain TCP stream is opened (``autostart_tls=False``, no ``ssl_context``)
    and wrapped in a client-side ``GRPCProtoSocket``.  The wrapper is entered
    through ``self.enter_async_context`` and stored as ``self._grpc_socket``.

    If building or entering the ``GRPCProtoSocket`` raises, the freshly opened
    TCP stream is closed before the exception propagates.  Nothing else owns
    it at that point, so without this guard it would leak.

    Returns:
        ``self``.
    """
    import contextlib

    await super().__aenter__()  # Does nothing
    stream = await anyio.connect_tcp(self._host, self._port, autostart_tls=False,
                                     tls_standard_compatible=False)
    async with contextlib.AsyncExitStack() as on_failure:
        # Arrange for the raw stream's __aexit__ to run if anything below raises.
        on_failure.push_async_exit(stream)
        config = GRPCConfiguration(client_side=True)
        self._grpc_socket = await self.enter_async_context(GRPCProtoSocket(config, stream))
        # Success: discard the guard without running it.  As in the original
        # code, the only thing registered on this object's exit stack is the
        # GRPCProtoSocket (presumably responsible for closing the stream).
        on_failure.pop_all()
    return self
async def _open_connection_https(self, location):
    """Open a TCP connection with TLS started immediately, and return it.

    Args:
        location (tuple(str, int)): Net location and port.  Element 0 is the
            host (e.g. '127.0.0.1' or 'example.org'); element 1 is the port
            (e.g. 80 or 25000).

    Returns:
        The stream produced by ``connect_tcp``.  Its ``_active`` attribute is
        set to ``True``; presumably this marks it as in use for the caller
        (confirm).

    Notes:
        * ``self.ssl_context`` is used for the handshake.
        * ``self.source_address`` is passed as ``bind_host``.
        * ``tls_standard_compatible=False`` is passed; presumably this
          tolerates peers that close without a TLS close_notify (confirm
          against the AnyIO docs).
    """
    host, port = location[0], location[1]
    tls_stream = await connect_tcp(host, port,
                                   ssl_context=self.ssl_context,
                                   bind_host=self.source_address,
                                   autostart_tls=True,
                                   tls_standard_compatible=False)
    tls_stream._active = True
    return tls_stream