fix: Timeout monkey patch not working on Linux. Implement native poll
override
This commit is contained in:
@@ -6,11 +6,14 @@ import io
|
|||||||
import select as _select
|
import select as _select
|
||||||
import socket as _socket
|
import socket as _socket
|
||||||
import tarfile
|
import tarfile
|
||||||
|
from contextlib import ExitStack
|
||||||
from select import select as _original_select
|
from select import select as _original_select
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
_original_poll = getattr(_select, "poll", None)
|
||||||
|
|
||||||
import docker
|
import docker
|
||||||
import docker.errors
|
import docker.errors
|
||||||
from docker.utils.socket import consume_socket_output, demux_adaptor, frames_iter
|
from docker.utils.socket import consume_socket_output, demux_adaptor, frames_iter
|
||||||
@@ -46,6 +49,7 @@ class DockerSandbox:
|
|||||||
security_opt: list[str] | None = None,
|
security_opt: list[str] | None = None,
|
||||||
cpu_limit: float = 8,
|
cpu_limit: float = 8,
|
||||||
memory_limit: str = "16g",
|
memory_limit: str = "16g",
|
||||||
|
command: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.container_name = container_name
|
self.container_name = container_name
|
||||||
self._image = image
|
self._image = image
|
||||||
@@ -60,6 +64,7 @@ class DockerSandbox:
|
|||||||
self._security_opt = security_opt
|
self._security_opt = security_opt
|
||||||
self._nano_cpus = int(cpu_limit * 1e9)
|
self._nano_cpus = int(cpu_limit * 1e9)
|
||||||
self._memory_limit = memory_limit
|
self._memory_limit = memory_limit
|
||||||
|
self._command = command
|
||||||
self._client: docker.DockerClient = docker.from_env()
|
self._client: docker.DockerClient = docker.from_env()
|
||||||
self._container: docker.models.containers.Container | None = None
|
self._container: docker.models.containers.Container | None = None
|
||||||
|
|
||||||
@@ -132,7 +137,7 @@ class DockerSandbox:
|
|||||||
run_kwargs["mem_limit"] = self._memory_limit
|
run_kwargs["mem_limit"] = self._memory_limit
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._container = self._client.containers.run(self._image, **run_kwargs)
|
self._container = self._client.containers.run(self._image, self._command, **run_kwargs)
|
||||||
except docker.errors.ImageNotFound:
|
except docker.errors.ImageNotFound:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Image {self._image!r} not found locally. "
|
f"Image {self._image!r} not found locally. "
|
||||||
@@ -258,9 +263,12 @@ class DockerSandbox:
|
|||||||
|
|
||||||
# TODO(fragile): timeout enforcement relies on private docker-py internals
|
# TODO(fragile): timeout enforcement relies on private docker-py internals
|
||||||
# (frames_iter, demux_adaptor, consume_socket_output from docker.utils.socket)
|
# (frames_iter, demux_adaptor, consume_socket_output from docker.utils.socket)
|
||||||
# and monkey-patches select.select for the duration of the read — not thread-safe
|
# and monkey-patches select.select / select.poll for the duration of the read
|
||||||
# if multiple exec() calls run concurrently. Replace when docker-py adds native
|
# — not thread-safe if multiple exec() calls run concurrently. Replace when
|
||||||
# per-call timeout support. See https://github.com/docker/docker-py/issues/2651
|
# docker-py adds native per-call timeout support.
|
||||||
|
# See https://github.com/docker/docker-py/issues/2651
|
||||||
|
#
|
||||||
|
# On Linux docker-py uses select.poll (not select.select), so both are patched.
|
||||||
try:
|
try:
|
||||||
exec_id = self._client.api.exec_create(
|
exec_id = self._client.api.exec_create(
|
||||||
self._container.id,
|
self._container.id,
|
||||||
@@ -271,17 +279,39 @@ class DockerSandbox:
|
|||||||
)
|
)
|
||||||
sock = self._client.api.exec_start(exec_id["Id"], socket=True)
|
sock = self._client.api.exec_start(exec_id["Id"], socket=True)
|
||||||
sock._sock.settimeout(timeout)
|
sock._sock.settimeout(timeout)
|
||||||
with patch.object(
|
|
||||||
_select,
|
timeout_ms = timeout * 1000
|
||||||
"select",
|
|
||||||
|
class _PollWithTimeout:
|
||||||
|
def __init__(self):
|
||||||
|
self._inner = _original_poll()
|
||||||
|
|
||||||
|
def register(self, fd, eventmask):
|
||||||
|
return self._inner.register(fd, eventmask)
|
||||||
|
|
||||||
|
def poll(self, *args):
|
||||||
|
result = self._inner.poll(timeout_ms)
|
||||||
|
if not result:
|
||||||
|
raise _socket.timeout(f"timed out after {timeout}s")
|
||||||
|
return result
|
||||||
|
|
||||||
|
with ExitStack() as stack:
|
||||||
|
stack.enter_context(patch.object(
|
||||||
|
_select, "select",
|
||||||
new=lambda rlist, wlist, xlist: _original_select(
|
new=lambda rlist, wlist, xlist: _original_select(
|
||||||
rlist, wlist, xlist, timeout
|
rlist, wlist, xlist, timeout
|
||||||
),
|
),
|
||||||
):
|
))
|
||||||
|
if _original_poll is not None:
|
||||||
|
stack.enter_context(
|
||||||
|
patch.object(_select, "poll", new=_PollWithTimeout)
|
||||||
|
)
|
||||||
gen = (demux_adaptor(*frame) for frame in frames_iter(sock, tty=False))
|
gen = (demux_adaptor(*frame) for frame in frames_iter(sock, tty=False))
|
||||||
stdout, stderr = consume_socket_output(gen, demux=True)
|
stdout, stderr = consume_socket_output(gen, demux=True)
|
||||||
|
|
||||||
exit_code = self._client.api.exec_inspect(exec_id["Id"])["ExitCode"] or 0
|
exit_code = self._client.api.exec_inspect(exec_id["Id"])["ExitCode"]
|
||||||
|
if exit_code is None:
|
||||||
|
exit_code = 0
|
||||||
output = (stdout or b"") + (stderr or b"")
|
output = (stdout or b"") + (stderr or b"")
|
||||||
return exit_code, output.decode("utf-8", errors="replace")
|
return exit_code, output.decode("utf-8", errors="replace")
|
||||||
except _socket.timeout:
|
except _socket.timeout:
|
||||||
|
|||||||
Reference in New Issue
Block a user