mirror of
https://git.yoctoproject.org/git/poky
synced 2026-01-01 13:58:04 +00:00
bitbake: hashserv: Add websocket connection implementation
Adds support to the hash equivalence client and server to communicate over websockets. Since websockets are message orientated instead of stream orientated, and new connection class is needed to handle them. Note that websocket support does require the 3rd party websockets python module be installed on the host, but it should not be required unless websockets are actually being used. (Bitbake rev: 56dd2fdbfb6350a9eef43a12aa529c8637887a7e) Signed-off-by: Joshua Watt <JPEWhacker@gmail.com> Signed-off-by: Richard Purdie <richard.purdie@linuxfoundation.org>
This commit is contained in:
parent
8f8501ed40
commit
2484bd8931
|
|
@ -10,7 +10,7 @@ import json
|
|||
import os
|
||||
import socket
|
||||
import sys
|
||||
from .connection import StreamConnection, DEFAULT_MAX_CHUNK
|
||||
from .connection import StreamConnection, WebsocketConnection, DEFAULT_MAX_CHUNK
|
||||
from .exceptions import ConnectionClosedError
|
||||
|
||||
|
||||
|
|
@ -47,6 +47,15 @@ class AsyncClient(object):
|
|||
|
||||
self._connect_sock = connect_sock
|
||||
|
||||
async def connect_websocket(self, uri):
|
||||
import websockets
|
||||
|
||||
async def connect_sock():
|
||||
websocket = await websockets.connect(uri, ping_interval=None)
|
||||
return WebsocketConnection(websocket, self.timeout)
|
||||
|
||||
self._connect_sock = connect_sock
|
||||
|
||||
async def setup_connection(self):
|
||||
# Send headers
|
||||
await self.socket.send("%s %s" % (self.proto_name, self.proto_version))
|
||||
|
|
|
|||
|
|
@ -93,3 +93,47 @@ class StreamConnection(object):
|
|||
if self.writer is not None:
|
||||
self.writer.close()
|
||||
self.writer = None
|
||||
|
||||
|
||||
class WebsocketConnection(object):
|
||||
def __init__(self, socket, timeout):
|
||||
self.socket = socket
|
||||
self.timeout = timeout
|
||||
|
||||
@property
|
||||
def address(self):
|
||||
return ":".join(str(s) for s in self.socket.remote_address)
|
||||
|
||||
async def send_message(self, msg):
|
||||
await self.send(json.dumps(msg))
|
||||
|
||||
async def recv_message(self):
|
||||
m = await self.recv()
|
||||
return json.loads(m)
|
||||
|
||||
async def send(self, msg):
|
||||
import websockets.exceptions
|
||||
|
||||
try:
|
||||
await self.socket.send(msg)
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
raise ConnectionClosedError("Connection closed")
|
||||
|
||||
async def recv(self):
|
||||
import websockets.exceptions
|
||||
|
||||
try:
|
||||
if self.timeout < 0:
|
||||
return await self.socket.recv()
|
||||
|
||||
try:
|
||||
return await asyncio.wait_for(self.socket.recv(), self.timeout)
|
||||
except asyncio.TimeoutError:
|
||||
raise ConnectionError("Timed out waiting for data")
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
raise ConnectionClosedError("Connection closed")
|
||||
|
||||
async def close(self):
|
||||
if self.socket is not None:
|
||||
await self.socket.close()
|
||||
self.socket = None
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import signal
|
|||
import socket
|
||||
import sys
|
||||
import multiprocessing
|
||||
from .connection import StreamConnection
|
||||
from .connection import StreamConnection, WebsocketConnection
|
||||
from .exceptions import ClientError, ServerError, ConnectionClosedError
|
||||
|
||||
|
||||
|
|
@ -178,6 +178,54 @@ class UnixStreamServer(StreamServer):
|
|||
os.unlink(self.path)
|
||||
|
||||
|
||||
class WebsocketsServer(object):
|
||||
def __init__(self, host, port, handler, logger):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.handler = handler
|
||||
self.logger = logger
|
||||
|
||||
def start(self, loop):
|
||||
import websockets.server
|
||||
|
||||
self.server = loop.run_until_complete(
|
||||
websockets.server.serve(
|
||||
self.client_handler,
|
||||
self.host,
|
||||
self.port,
|
||||
ping_interval=None,
|
||||
)
|
||||
)
|
||||
|
||||
for s in self.server.sockets:
|
||||
self.logger.debug("Listening on %r" % (s.getsockname(),))
|
||||
|
||||
# Enable keep alives. This prevents broken client connections
|
||||
# from persisting on the server for long periods of time.
|
||||
s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
|
||||
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
|
||||
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
|
||||
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
|
||||
|
||||
name = self.server.sockets[0].getsockname()
|
||||
if self.server.sockets[0].family == socket.AF_INET6:
|
||||
self.address = "ws://[%s]:%d" % (name[0], name[1])
|
||||
else:
|
||||
self.address = "ws://%s:%d" % (name[0], name[1])
|
||||
|
||||
return [self.server.wait_closed()]
|
||||
|
||||
async def stop(self):
|
||||
self.server.close()
|
||||
|
||||
def cleanup(self):
|
||||
pass
|
||||
|
||||
async def client_handler(self, websocket):
|
||||
socket = WebsocketConnection(websocket, -1)
|
||||
await self.handler(socket)
|
||||
|
||||
|
||||
class AsyncServer(object):
|
||||
def __init__(self, logger):
|
||||
self.logger = logger
|
||||
|
|
@ -190,6 +238,9 @@ class AsyncServer(object):
|
|||
def start_unix_server(self, path):
|
||||
self.server = UnixStreamServer(path, self._client_handler, self.logger)
|
||||
|
||||
def start_websocket_server(self, host, port):
|
||||
self.server = WebsocketsServer(host, port, self._client_handler, self.logger)
|
||||
|
||||
async def _client_handler(self, socket):
|
||||
try:
|
||||
client = self.accept_client(socket)
|
||||
|
|
|
|||
|
|
@ -9,11 +9,15 @@ import re
|
|||
import sqlite3
|
||||
import itertools
|
||||
import json
|
||||
from urllib.parse import urlparse
|
||||
|
||||
UNIX_PREFIX = "unix://"
|
||||
WS_PREFIX = "ws://"
|
||||
WSS_PREFIX = "wss://"
|
||||
|
||||
ADDR_TYPE_UNIX = 0
|
||||
ADDR_TYPE_TCP = 1
|
||||
ADDR_TYPE_WS = 2
|
||||
|
||||
UNIHASH_TABLE_DEFINITION = (
|
||||
("method", "TEXT NOT NULL", "UNIQUE"),
|
||||
|
|
@ -84,6 +88,8 @@ def setup_database(database, sync=True):
|
|||
def parse_address(addr):
|
||||
if addr.startswith(UNIX_PREFIX):
|
||||
return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX):],))
|
||||
elif addr.startswith(WS_PREFIX) or addr.startswith(WSS_PREFIX):
|
||||
return (ADDR_TYPE_WS, (addr,))
|
||||
else:
|
||||
m = re.match(r'\[(?P<host>[^\]]*)\]:(?P<port>\d+)$', addr)
|
||||
if m is not None:
|
||||
|
|
@ -103,6 +109,9 @@ def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False):
|
|||
(typ, a) = parse_address(addr)
|
||||
if typ == ADDR_TYPE_UNIX:
|
||||
s.start_unix_server(*a)
|
||||
elif typ == ADDR_TYPE_WS:
|
||||
url = urlparse(a[0])
|
||||
s.start_websocket_server(url.hostname, url.port)
|
||||
else:
|
||||
s.start_tcp_server(*a)
|
||||
|
||||
|
|
@ -116,6 +125,8 @@ def create_client(addr):
|
|||
(typ, a) = parse_address(addr)
|
||||
if typ == ADDR_TYPE_UNIX:
|
||||
c.connect_unix(*a)
|
||||
elif typ == ADDR_TYPE_WS:
|
||||
c.connect_websocket(*a)
|
||||
else:
|
||||
c.connect_tcp(*a)
|
||||
|
||||
|
|
@ -128,6 +139,8 @@ async def create_async_client(addr):
|
|||
(typ, a) = parse_address(addr)
|
||||
if typ == ADDR_TYPE_UNIX:
|
||||
await c.connect_unix(*a)
|
||||
elif typ == ADDR_TYPE_WS:
|
||||
await c.connect_websocket(*a)
|
||||
else:
|
||||
await c.connect_tcp(*a)
|
||||
|
||||
|
|
|
|||
|
|
@ -115,6 +115,7 @@ class Client(bb.asyncrpc.Client):
|
|||
super().__init__()
|
||||
self._add_methods(
|
||||
"connect_tcp",
|
||||
"connect_websocket",
|
||||
"get_unihash",
|
||||
"report_unihash",
|
||||
"report_unihash_equiv",
|
||||
|
|
|
|||
|
|
@ -483,3 +483,20 @@ class TestHashEquivalenceTCPServer(HashEquivalenceTestSetup, HashEquivalenceComm
|
|||
# If IPv6 is enabled, it should be safe to use localhost directly, in general
|
||||
# case it is more reliable to resolve the IP address explicitly.
|
||||
return socket.gethostbyname("localhost") + ":0"
|
||||
|
||||
|
||||
class TestHashEquivalenceWebsocketServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
|
||||
def setUp(self):
|
||||
try:
|
||||
import websockets
|
||||
except ImportError as e:
|
||||
self.skipTest(str(e))
|
||||
|
||||
super().setUp()
|
||||
|
||||
def get_server_addr(self, server_idx):
|
||||
# Some hosts cause asyncio module to misbehave, when IPv6 is not enabled.
|
||||
# If IPv6 is enabled, it should be safe to use localhost directly, in general
|
||||
# case it is more reliable to resolve the IP address explicitly.
|
||||
host = socket.gethostbyname("localhost")
|
||||
return "ws://%s:0" % host
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user