You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
	
	
		
			203 lines
		
	
	
		
			6.4 KiB
		
	
	
	
		
			Python
		
	
			
		
		
	
	
			203 lines
		
	
	
		
			6.4 KiB
		
	
	
	
		
			Python
		
	
"""
 | 
						|
_handshake.py
 | 
						|
websocket - WebSocket client library for Python
 | 
						|
 | 
						|
Copyright 2024 engn33r
 | 
						|
 | 
						|
Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
you may not use this file except in compliance with the License.
 | 
						|
You may obtain a copy of the License at
 | 
						|
 | 
						|
    http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
 | 
						|
Unless required by applicable law or agreed to in writing, software
 | 
						|
distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
See the License for the specific language governing permissions and
 | 
						|
limitations under the License.
 | 
						|
"""
 | 
						|
import hashlib
 | 
						|
import hmac
 | 
						|
import os
 | 
						|
from base64 import encodebytes as base64encode
 | 
						|
from http import HTTPStatus
 | 
						|
 | 
						|
from ._cookiejar import SimpleCookieJar
 | 
						|
from ._exceptions import WebSocketException, WebSocketBadStatusException
 | 
						|
from ._http import read_headers
 | 
						|
from ._logging import dump, error
 | 
						|
from ._socket import send
 | 
						|
 | 
						|
__all__ = ["handshake_response", "handshake", "SUPPORTED_REDIRECT_STATUSES"]
 | 
						|
 | 
						|
# WebSocket protocol version advertised in the Sec-WebSocket-Version header.
VERSION = 13

# HTTP statuses for which handshake() returns the response without
# validating the WebSocket upgrade headers.
SUPPORTED_REDIRECT_STATUSES = (
    HTTPStatus.MOVED_PERMANENTLY,
    HTTPStatus.FOUND,
    HTTPStatus.SEE_OTHER,
    HTTPStatus.TEMPORARY_REDIRECT,
    HTTPStatus.PERMANENT_REDIRECT,
)

# Default set of statuses that _get_resp_headers() accepts without raising:
# every supported redirect plus "101 Switching Protocols".
SUCCESS_STATUSES = (*SUPPORTED_REDIRECT_STATUSES, HTTPStatus.SWITCHING_PROTOCOLS)
 | 
						|
 | 
						|
# Module-level cookie jar shared by every handshake in this module:
# handshake_response feeds it the "set-cookie" response header, and
# _get_handshake_headers reads the cookies stored for the target host.
CookieJar = SimpleCookieJar()
 | 
						|
 | 
						|
 | 
						|
class handshake_response:
    """Outcome of an opening handshake.

    Attributes:
        status: HTTP status code of the server response.
        headers: response headers as passed in by the caller.
        subprotocol: negotiated subprotocol, or None.

    Constructing an instance also hands the response's ``set-cookie`` header
    value (None when absent) to the module-level ``CookieJar``.
    """

    def __init__(self, status: int, headers: dict, subprotocol):
        self.status, self.headers, self.subprotocol = status, headers, subprotocol
        set_cookie = headers.get("set-cookie")
        CookieJar.add(set_cookie)
 | 
						|
 | 
						|
 | 
						|
def handshake(
    sock, url: str, hostname: str, port: int, resource: str, **options
) -> handshake_response:
    """Send the opening handshake request on *sock* and check the reply.

    The request lines come from ``_get_handshake_headers``; they are joined
    with CRLF, sent, and logged via ``dump``. A response whose status is in
    ``SUPPORTED_REDIRECT_STATUSES`` is returned as-is (no subprotocol) without
    header validation; any other accepted response must pass ``_validate``.

    Raises:
        WebSocketException: the response headers failed validation.
        WebSocketBadStatusException: propagated from ``_get_resp_headers``
            when the status is not an accepted one.
    """
    request_lines, key = _get_handshake_headers(resource, url, hostname, port, options)

    request = "\r\n".join(request_lines)
    send(sock, request)
    dump("request header", request)

    status, resp = _get_resp_headers(sock)
    if status in SUPPORTED_REDIRECT_STATUSES:
        # Redirects carry no upgrade headers to validate.
        return handshake_response(status, resp, None)

    success, subproto = _validate(resp, key, options.get("subprotocols"))
    if success:
        return handshake_response(status, resp, subproto)
    raise WebSocketException("Invalid WebSocket Header")
 | 
						|
 | 
						|
 | 
						|
def _pack_hostname(hostname: str) -> str:
 | 
						|
    # IPv6 address
 | 
						|
    if ":" in hostname:
 | 
						|
        return f"[{hostname}]"
 | 
						|
    return hostname
 | 
						|
 | 
						|
 | 
						|
def _find_custom_header(custom, name: str) -> tuple:
    """Look up header *name* in the user-supplied ``header`` option.

    *custom* may be a dict (name -> value) or an iterable of
    ``"Name: value"`` lines. Matching is case-insensitive; for a dict an
    exact-cased key takes precedence.

    Returns:
        ``(found, value)``; ``(False, None)`` when the header is absent.
    """
    if not custom:
        return False, None

    wanted = name.lower()
    if isinstance(custom, dict):
        if name in custom:
            return True, custom[name]
        for field, value in custom.items():
            if isinstance(field, str) and field.strip().lower() == wanted:
                return True, value
        return False, None

    for line in custom:
        if not isinstance(line, str):
            continue
        field, sep, value = line.partition(":")
        if sep and field.strip().lower() == wanted:
            return True, value.strip()
    return False, None


def _get_handshake_headers(
    resource: str, url: str, host: str, port: int, options: dict
) -> tuple:
    """Build the request lines of the opening handshake.

    Args:
        resource: request target placed on the ``GET`` line.
        url: full URL; only its scheme (text before the first ``:``) is used,
            to choose ``https://`` vs ``http://`` for the default Origin.
        host: target host, used for the Host/Origin headers and cookie lookup.
        port: target port; omitted from Host/Origin when it is 80 or 443.
        options: optional settings read here: ``host``, ``suppress_origin``,
            ``origin``, ``header`` (dict or list of header lines),
            ``connection`` (a complete header line), ``subprotocols``,
            ``cookie``.

    Returns:
        ``(headers, key)`` where *headers* is the list of request lines,
        ending with two empty strings so that a CRLF join terminates the
        header block, and *key* is the Sec-WebSocket-Key value to validate
        the server's reply against.
    """
    headers = [f"GET {resource} HTTP/1.1", "Upgrade: websocket"]
    if port in [80, 443]:
        hostport = _pack_hostname(host)
    else:
        hostport = f"{_pack_hostname(host)}:{port}"
    if options.get("host"):
        headers.append(f'Host: {options["host"]}')
    else:
        headers.append(f"Host: {hostport}")

    # scheme indicates whether http or https is used in Origin
    # The same approach is used in parse_url of _url.py to set default port
    scheme, _rest = url.split(":", 1)
    if not options.get("suppress_origin"):
        if "origin" in options and options["origin"] is not None:
            headers.append(f'Origin: {options["origin"]}')
        elif scheme == "wss":
            headers.append(f"Origin: https://{hostport}")
        else:
            headers.append(f"Origin: http://{hostport}")

    key = _create_sec_websocket_key()
    custom = options.get("header")

    # Append Sec-WebSocket-Key & Sec-WebSocket-Version if not manually specified.
    # The lookup handles both dict and list-of-lines forms, case-insensitively,
    # so a user-supplied key is never duplicated and is the one returned for
    # validating the server's Sec-WebSocket-Accept.
    has_key, custom_key = _find_custom_header(custom, "Sec-WebSocket-Key")
    if has_key:
        key = custom_key
    else:
        headers.append(f"Sec-WebSocket-Key: {key}")

    has_version, _version = _find_custom_header(custom, "Sec-WebSocket-Version")
    if not has_version:
        headers.append(f"Sec-WebSocket-Version: {VERSION}")

    if not options.get("connection"):
        headers.append("Connection: Upgrade")
    else:
        # The option is appended verbatim, i.e. it must be a full header line.
        headers.append(options["connection"])

    if subprotocols := options.get("subprotocols"):
        headers.append(f'Sec-WebSocket-Protocol: {",".join(subprotocols)}')

    if custom:
        if isinstance(custom, dict):
            # Entries whose value is None are skipped.
            custom = [": ".join([k, v]) for k, v in custom.items() if v is not None]
        headers.extend(custom)

    server_cookie = CookieJar.get(host)
    client_cookie = options.get("cookie", None)

    if cookie := "; ".join(filter(None, [server_cookie, client_cookie])):
        headers.append(f"Cookie: {cookie}")

    headers.extend(("", ""))
    return headers, key
 | 
						|
 | 
						|
 | 
						|
def _get_resp_headers(sock, success_statuses: tuple = SUCCESS_STATUSES) -> tuple:
    """Read the handshake response headers and reject unexpected statuses.

    Args:
        sock: connected socket-like object passed to ``read_headers``.
        success_statuses: statuses that are accepted without raising.

    Returns:
        ``(status, resp_headers)`` when *status* is in *success_statuses*.

    Raises:
        WebSocketBadStatusException: for any other status. When the response
            has a usable ``content-length`` header, one ``recv`` of that size
            is attempted and the result is attached as the response body;
            otherwise the body is None.
    """
    status, resp_headers, status_message = read_headers(sock)
    if status in success_statuses:
        return status, resp_headers

    response_body = None
    content_len = resp_headers.get("content-length")
    if content_len:
        # A malformed or negative Content-Length must not mask the bad-status
        # error with a ValueError from int()/recv(); skip the body instead.
        try:
            body_size = int(content_len)
        except (TypeError, ValueError):
            body_size = None
        if body_size is not None and body_size >= 0:
            # read the body of the HTTP error message response and include it
            # in the exception
            # NOTE(review): a single recv() may return fewer bytes than
            # requested; the body is best-effort diagnostic data only.
            response_body = sock.recv(body_size)

    raise WebSocketBadStatusException(
        f"Handshake status {status} {status_message} -+-+- {resp_headers} -+-+- {response_body}",
        status,
        status_message,
        resp_headers,
        response_body,
    )
 | 
						|
 | 
						|
 | 
						|
_HEADERS_TO_CHECK = {
 | 
						|
    "upgrade": "websocket",
 | 
						|
    "connection": "upgrade",
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
def _validate(headers, key: str, subprotocols) -> tuple:
 | 
						|
    subproto = None
 | 
						|
    for k, v in _HEADERS_TO_CHECK.items():
 | 
						|
        r = headers.get(k, None)
 | 
						|
        if not r:
 | 
						|
            return False, None
 | 
						|
        r = [x.strip().lower() for x in r.split(",")]
 | 
						|
        if v not in r:
 | 
						|
            return False, None
 | 
						|
 | 
						|
    if subprotocols:
 | 
						|
        subproto = headers.get("sec-websocket-protocol", None)
 | 
						|
        if not subproto or subproto.lower() not in [s.lower() for s in subprotocols]:
 | 
						|
            error(f"Invalid subprotocol: {subprotocols}")
 | 
						|
            return False, None
 | 
						|
        subproto = subproto.lower()
 | 
						|
 | 
						|
    result = headers.get("sec-websocket-accept", None)
 | 
						|
    if not result:
 | 
						|
        return False, None
 | 
						|
    result = result.lower()
 | 
						|
 | 
						|
    if isinstance(result, str):
 | 
						|
        result = result.encode("utf-8")
 | 
						|
 | 
						|
    value = f"{key}258EAFA5-E914-47DA-95CA-C5AB0DC85B11".encode("utf-8")
 | 
						|
    hashed = base64encode(hashlib.sha1(value).digest()).strip().lower()
 | 
						|
 | 
						|
    if hmac.compare_digest(hashed, result):
 | 
						|
        return True, subproto
 | 
						|
    else:
 | 
						|
        return False, None
 | 
						|
 | 
						|
 | 
						|
def _create_sec_websocket_key() -> str:
 | 
						|
    randomness = os.urandom(16)
 | 
						|
    return base64encode(randomness).decode("utf-8").strip()
 |