You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
	
	
		
			211 lines
		
	
	
		
			6.9 KiB
		
	
	
	
		
			Python
		
	
			
		
		
	
	
			211 lines
		
	
	
		
			6.9 KiB
		
	
	
	
		
			Python
		
	
import errno
 | 
						|
import signal
 | 
						|
import socket
 | 
						|
from subprocess import Popen
 | 
						|
import sys
 | 
						|
import time
 | 
						|
import unittest
 | 
						|
 | 
						|
from tornado.netutil import (
 | 
						|
    BlockingResolver,
 | 
						|
    OverrideResolver,
 | 
						|
    ThreadedResolver,
 | 
						|
    is_valid_ip,
 | 
						|
    bind_sockets,
 | 
						|
)
 | 
						|
from tornado.testing import AsyncTestCase, gen_test, bind_unused_port
 | 
						|
from tornado.test.util import skipIfNoNetwork, abstract_base_test
 | 
						|
 | 
						|
import typing
 | 
						|
 | 
						|
try:
 | 
						|
    import pycares  # type: ignore
 | 
						|
except ImportError:
 | 
						|
    pycares = None
 | 
						|
else:
 | 
						|
    from tornado.platform.caresresolver import CaresResolver
 | 
						|
 | 
						|
 | 
						|
@abstract_base_test
class _ResolverTestMixin(AsyncTestCase):
    """Shared success-path test for resolver implementations.

    The test reads ``self.resolver``; the class-level default is ``None``,
    so a concrete subclass is expected to assign a resolver before the
    test runs.
    """

    # Resolver under test; replaced by subclasses.
    resolver: typing.Any = None

    @gen_test
    def test_localhost(self):
        """Resolve "localhost" on port 80 and require a loopback result."""
        addrinfo = yield self.resolver.resolve("localhost", 80, socket.AF_UNSPEC)

        # Usually localhost resolves to the ipv4 loopback address alone or
        # to ipv4+ipv6, but some versions of pycares only return the ipv6
        # entry, so either one on its own is accepted.
        ipv4_loopback = (socket.AF_INET, ("127.0.0.1", 80))
        ipv6_loopback = (socket.AF_INET6, ("::1", 80))
        has_loopback = (ipv4_loopback in addrinfo) or (ipv6_loopback in addrinfo)
        self.assertTrue(has_loopback, f"loopback address not found in {addrinfo}")
 | 
						|
 | 
						|
 | 
						|
# A name-resolution error cannot be produced both quickly and consistently,
# so the error case lives in its own mixin and relies on mocks as needed.
@abstract_base_test
class _ResolverErrorTestMixin(AsyncTestCase):
    """Shared failure-path test for resolver implementations.

    The test reads ``self.resolver``; the class-level default is ``None``,
    so a concrete subclass is expected to assign a resolver before the
    test runs.
    """

    # Resolver under test; replaced by subclasses.
    resolver: typing.Any = None

    @gen_test
    def test_bad_host(self):
        """Resolving a name containing spaces must raise ``IOError``."""
        bad_name = "an invalid domain"
        with self.assertRaises(IOError):
            yield self.resolver.resolve(bad_name, 80, socket.AF_UNSPEC)
 | 
						|
 | 
						|
 | 
						|
def _failing_getaddrinfo(*args, **kwargs):
    """Dummy implementation of getaddrinfo for use in mocks.

    Every positional and keyword argument is accepted and ignored, so the
    mock fails the same way however the caller spells the call (the real
    ``socket.getaddrinfo`` also accepts ``family=``, ``type=``, ``proto=``
    and ``flags=`` as keywords).

    Raises:
        socket.gaierror: always, with ``errno.EIO`` and a fixed message.
    """
    raise socket.gaierror(errno.EIO, "mock: lookup failed")
 | 
						|
 | 
						|
 | 
						|
@skipIfNoNetwork
class BlockingResolverTest(_ResolverTestMixin):
    """Runs the shared resolver tests against ``BlockingResolver``."""

    def setUp(self) -> None:
        """Run the base-class setup, then install a fresh ``BlockingResolver``."""
        super().setUp()
        blocking = BlockingResolver()
        self.resolver = blocking
 | 
						|
 | 
						|
 | 
						|
# Tests built on getaddrinfo mock it in order to produce errors reliably:
# some configurations are slow to report a failure and would exceed our
# default timeout.
class BlockingResolverErrorTest(_ResolverErrorTestMixin):
    """Runs the shared error test with ``socket.getaddrinfo`` forced to fail."""

    def setUp(self) -> None:
        """Swap ``socket.getaddrinfo`` for the failing stub and build the resolver."""
        super().setUp()
        # Remember the real function so tearDown can put it back.
        self.real_getaddrinfo = socket.getaddrinfo
        socket.getaddrinfo = _failing_getaddrinfo
        self.resolver = BlockingResolver()

    def tearDown(self) -> None:
        """Restore the real ``socket.getaddrinfo`` before the base-class teardown."""
        original = self.real_getaddrinfo
        socket.getaddrinfo = original
        super().tearDown()
 | 
						|
 | 
						|
 | 
						|
class OverrideResolverTest(_ResolverTestMixin):
    """Runs the shared resolver tests plus override checks on ``OverrideResolver``."""

    def setUp(self) -> None:
        """Wrap a ``BlockingResolver`` with overrides for "google.com" on port 80."""
        super().setUp()
        ipv4_target = ("1.2.3.4", 80)
        ipv6_target = ("2a02:6b8:7c:40c:c51e:495f:e23a:3", 80)
        # Keys are (host, port) or (host, port, family) tuples.
        overrides = {
            ("google.com", 80): ipv4_target,
            ("google.com", 80, socket.AF_INET): ipv4_target,
            ("google.com", 80, socket.AF_INET6): ipv6_target,
        }
        self.resolver = OverrideResolver(BlockingResolver(), overrides)

    @gen_test
    def test_resolve_multiaddr(self):
        """Each address family must produce the address mapped for that family."""
        ipv4_result = yield self.resolver.resolve("google.com", 80, socket.AF_INET)
        expected_ipv4 = (socket.AF_INET, ("1.2.3.4", 80))
        self.assertIn(expected_ipv4, ipv4_result)

        ipv6_result = yield self.resolver.resolve("google.com", 80, socket.AF_INET6)
        expected_ipv6 = (
            socket.AF_INET6,
            ("2a02:6b8:7c:40c:c51e:495f:e23a:3", 80, 0, 0),
        )
        self.assertIn(expected_ipv6, ipv6_result)
 | 
						|
 | 
						|
 | 
						|
@skipIfNoNetwork
class ThreadedResolverTest(_ResolverTestMixin):
    """Runs the shared resolver tests against ``ThreadedResolver``."""

    def setUp(self) -> None:
        """Run the base-class setup, then install a fresh ``ThreadedResolver``."""
        super().setUp()
        threaded = ThreadedResolver()
        self.resolver = threaded

    def tearDown(self) -> None:
        """Close the resolver before the base-class teardown runs."""
        resolver = self.resolver
        resolver.close()
        super().tearDown()
 | 
						|
 | 
						|
 | 
						|
class ThreadedResolverErrorTest(_ResolverErrorTestMixin):
    """Runs the shared error test against ``ThreadedResolver``.

    ``socket.getaddrinfo`` is replaced with ``_failing_getaddrinfo`` for the
    duration of each test so the failure is immediate and deterministic.
    """

    def setUp(self):
        """Build the resolver under test and install the failing stub."""
        super().setUp()
        # Use the resolver this class is named for; a BlockingResolver here
        # would only repeat what BlockingResolverErrorTest already covers.
        self.resolver = ThreadedResolver()
        # Remember the real function so tearDown can put it back.
        self.real_getaddrinfo = socket.getaddrinfo
        socket.getaddrinfo = _failing_getaddrinfo

    def tearDown(self):
        """Restore ``socket.getaddrinfo`` and close the resolver."""
        socket.getaddrinfo = self.real_getaddrinfo
        # Mirrors ThreadedResolverTest.tearDown, which also closes its resolver.
        self.resolver.close()
        super().tearDown()
 | 
						|
 | 
						|
 | 
						|
@skipIfNoNetwork
@unittest.skipIf(sys.platform == "win32", "preexec_fn not available on win32")
class ThreadedResolverImportTest(unittest.TestCase):
    """Checks that an import which resolves at import time terminates."""

    def test_import(self):
        """Import the helper module in a subprocess and require exit code 0.

        The test fails with "import timed out" if the child has not exited
        within ``TIMEOUT`` seconds.
        """
        TIMEOUT = 5

        # Test for a deadlock when importing a module that runs the
        # ThreadedResolver at import-time. The module imported below,
        # tornado.test.resolve_test_helper, is presumably where the full
        # explanation lives.
        command = [sys.executable, "-c", "import tornado.test.resolve_test_helper"]

        # Use the monotonic clock: wall-clock adjustments must not be able to
        # shorten or extend the polling window.
        deadline = time.monotonic() + TIMEOUT
        # The child schedules SIGALRM for itself after TIMEOUT seconds.
        popen = Popen(command, preexec_fn=lambda: signal.alarm(TIMEOUT))
        try:
            while time.monotonic() < deadline:
                return_code = popen.poll()
                if return_code is not None:
                    self.assertEqual(0, return_code)
                    return  # Success.
                time.sleep(0.05)

            self.fail("import timed out")
        finally:
            # Never leave the child running or unreaped when the test ends.
            if popen.poll() is None:
                popen.kill()
                popen.wait()
 | 
						|
 | 
						|
 | 
						|
# Errors are deliberately not tested with CaresResolver.
# Some DNS-hijacking ISPs (e.g. Time Warner) answer with non-empty results
# alongside an NXDOMAIN status code. Most resolvers treat that as an error,
# but C-ares hands back the results, which makes the "bad_host" tests
# unreliable. C-ares will also attempt to resolve malformed names, such as
# the name with spaces used in that test.
@skipIfNoNetwork
@unittest.skipIf(pycares is None, "pycares module not present")
@unittest.skipIf(sys.platform == "win32", "pycares doesn't return loopback on windows")
@unittest.skipIf(sys.platform == "darwin", "pycares doesn't return 127.0.0.1 on darwin")
class CaresResolverTest(_ResolverTestMixin):
    """Runs the shared resolver tests against ``CaresResolver``."""

    def setUp(self) -> None:
        """Run the base-class setup, then install a fresh ``CaresResolver``."""
        super().setUp()
        cares = CaresResolver()
        self.resolver = cares
 | 
						|
 | 
						|
 | 
						|
class IsValidIPTest(unittest.TestCase):
    """Accept/reject table for ``is_valid_ip``."""

    def test_is_valid_ip(self):
        """Literal IPv4/IPv6 addresses are accepted; everything else is rejected."""
        accepted = (
            "127.0.0.1",
            "4.4.4.4",
            "::1",
            "2620:0:1cfe:face:b00c::3",
        )
        rejected = (
            # Host names rather than addresses.
            "www.google.com",
            "localhost",
            # Addresses with stray characters around them.
            "4.4.4.4<",
            " 127.0.0.1",
            # Empty, whitespace-only and control-character input.
            "",
            " ",
            "\n",
            "\x00",
            # Overlong input.
            "a" * 100,
        )

        for candidate in accepted:
            self.assertTrue(is_valid_ip(candidate))
        for candidate in rejected:
            self.assertFalse(is_valid_ip(candidate))
 | 
						|
 | 
						|
 | 
						|
class TestPortAllocation(unittest.TestCase):
    """Port-number checks for ``bind_sockets``."""

    def test_same_port_allocation(self):
        """Every socket bound for port 0 on "localhost" reports the same port."""
        bound = bind_sockets(0, "localhost")
        try:
            first_port = bound[0].getsockname()[1]
            remaining = bound[1:]
            same_port = all(s.getsockname()[1] == first_port for s in remaining)
            self.assertTrue(same_port)
        finally:
            for bound_sock in bound:
                bound_sock.close()

    @unittest.skipIf(
        not hasattr(socket, "SO_REUSEPORT"), "SO_REUSEPORT is not supported"
    )
    def test_reuse_port(self):
        """A port already bound with reuse_port=True can be bound again."""
        # Starts empty so the cleanup loop is safe if bind_sockets raises.
        rebound: typing.List[socket.socket] = []
        first_sock, port = bind_unused_port(reuse_port=True)
        try:
            rebound = bind_sockets(port, "127.0.0.1", reuse_port=True)
            ports_match = all(s.getsockname()[1] == port for s in rebound)
            self.assertTrue(ports_match)
        finally:
            first_sock.close()
            for extra_sock in rebound:
                extra_sock.close()
 |