diff --git a/docker_agent_sandbox/sandbox.py b/docker_agent_sandbox/sandbox.py index 009a699..1c8e34a 100644 --- a/docker_agent_sandbox/sandbox.py +++ b/docker_agent_sandbox/sandbox.py @@ -6,11 +6,14 @@ import io import select as _select import socket as _socket import tarfile +from contextlib import ExitStack from select import select as _original_select from pathlib import Path from typing import TYPE_CHECKING from unittest.mock import patch +_original_poll = getattr(_select, "poll", None) + import docker import docker.errors from docker.utils.socket import consume_socket_output, demux_adaptor, frames_iter @@ -46,6 +49,7 @@ class DockerSandbox: security_opt: list[str] | None = None, cpu_limit: float = 8, memory_limit: str = "16g", + command: str | None = None, ) -> None: self.container_name = container_name self._image = image @@ -60,6 +64,7 @@ class DockerSandbox: self._security_opt = security_opt self._nano_cpus = int(cpu_limit * 1e9) self._memory_limit = memory_limit + self._command = command self._client: docker.DockerClient = docker.from_env() self._container: docker.models.containers.Container | None = None @@ -132,7 +137,7 @@ class DockerSandbox: run_kwargs["mem_limit"] = self._memory_limit try: - self._container = self._client.containers.run(self._image, **run_kwargs) + self._container = self._client.containers.run(self._image, self._command, **run_kwargs) except docker.errors.ImageNotFound: raise RuntimeError( f"Image {self._image!r} not found locally. " @@ -258,9 +263,12 @@ class DockerSandbox: # TODO(fragile): timeout enforcement relies on private docker-py internals # (frames_iter, demux_adaptor, consume_socket_output from docker.utils.socket) - # and monkey-patches select.select for the duration of the read — not thread-safe - # if multiple exec() calls run concurrently. Replace when docker-py adds native - # per-call timeout support. See https://github.com/docker/docker-py/issues/2651 + # and monkey-patches select.select / select.poll for the duration of the read + # — not thread-safe if multiple exec() calls run concurrently. Replace when + # docker-py adds native per-call timeout support. + # See https://github.com/docker/docker-py/issues/2651 + # + # On Linux docker-py uses select.poll (not select.select), so both are patched. try: exec_id = self._client.api.exec_create( self._container.id, @@ -271,17 +279,39 @@ class DockerSandbox: ) sock = self._client.api.exec_start(exec_id["Id"], socket=True) sock._sock.settimeout(timeout) - with patch.object( - _select, - "select", - new=lambda rlist, wlist, xlist: _original_select( - rlist, wlist, xlist, timeout - ), - ): + + timeout_ms = timeout * 1000 + + class _PollWithTimeout: + def __init__(self): + self._inner = _original_poll() + + def register(self, fd, eventmask): + return self._inner.register(fd, eventmask) + + def poll(self, *args): + result = self._inner.poll(timeout_ms) + if not result: + raise _socket.timeout(f"timed out after {timeout}s") + return result + + with ExitStack() as stack: + stack.enter_context(patch.object( + _select, "select", + new=lambda rlist, wlist, xlist: _original_select( + rlist, wlist, xlist, timeout + ), + )) + if _original_poll is not None: + stack.enter_context( + patch.object(_select, "poll", new=_PollWithTimeout) + ) gen = (demux_adaptor(*frame) for frame in frames_iter(sock, tty=False)) stdout, stderr = consume_socket_output(gen, demux=True) - exit_code = self._client.api.exec_inspect(exec_id["Id"])["ExitCode"] or 0 + exit_code = self._client.api.exec_inspect(exec_id["Id"])["ExitCode"] + if exit_code is None: + exit_code = 0 output = (stdout or b"") + (stderr or b"") return exit_code, output.decode("utf-8", errors="replace") except _socket.timeout: