Skip to content

Commit 801994a

Browse files
authored
Better HTTP detection in the agent (#3307)
* Improved detection logic * Docs * Fixed UT to check stream end * Clippy pls
1 parent 6b97972 commit 801994a

File tree

8 files changed

+311
-203
lines changed

8 files changed

+311
-203
lines changed

changelog.d/3296.changed.md

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Improved HTTP detection logic in the mirrord agent.

mirrord/agent/src/http.rs

+133-24
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
1+
use std::{io, ops::Not, time::Duration};
2+
3+
use bytes::BytesMut;
4+
use futures::future::OptionFuture;
5+
use httparse::Status;
6+
use tokio::{
7+
io::{AsyncRead, AsyncReadExt},
8+
time::Instant,
9+
};
10+
use tracing::Level;
11+
12+
use crate::util::rolledback_stream::RolledBackStream;
13+
114
/// Helper enum for representing HTTP/1.x and HTTP/2, which are handled very differently in some
215
/// parts of the code.
316
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
@@ -12,34 +25,130 @@ pub enum HttpVersion {
1225
impl HttpVersion {
1326
/// Default start of an HTTP/2 request.
1427
///
15-
/// Used in [`Self::new`] to check if the connection should be treated as HTTP/2.
16-
pub const H2_PREFACE: &'static [u8; 14] = b"PRI * HTTP/2.0";
28+
/// Used in [`Self::detect`] to check if the connection should be treated as HTTP/2.
29+
const H2_PREFACE: &'static [u8; 14] = b"PRI * HTTP/2.0";
1730

18-
/// Controls the amount of data we read when trying to detect if the stream's first bytes
19-
/// contain an HTTP request. Used in [`Self::new`].
20-
///
21-
/// **WARNING**: Can't be too small, otherwise we end up accepting things like "Foo " as valid
22-
/// HTTP requests.
23-
pub const MINIMAL_HEADER_SIZE: usize = 10;
31+
/// Checks if the given `buffer` contains a prefix of a valid HTTP/1.x or HTTP/2 request.
32+
#[tracing::instrument(level = Level::TRACE, ret)]
33+
pub fn detect(buffer: &[u8]) -> DetectedHttpVersion {
34+
if buffer.starts_with(Self::H2_PREFACE) {
35+
return DetectedHttpVersion::Http(Self::V2);
36+
}
2437

25-
/// Checks if `buffer` contains a prefix of a valid HTTP/1.x request, or if it could be an
26-
/// HTTP/2 request by comparing it with a slice of [`Self::H2_PREFACE`].
27-
///
28-
/// The given `buffer` must contain at least [`Self::MINIMAL_HEADER_SIZE`] bytes, otherwise this
29-
/// function always returns [`None`].
30-
#[tracing::instrument(level = "trace")]
31-
pub fn new(buffer: &[u8]) -> Option<Self> {
3238
let mut empty_headers = [httparse::EMPTY_HEADER; 0];
39+
let mut request = httparse::Request::new(&mut empty_headers);
40+
match request.parse(buffer) {
41+
Ok(Status::Complete(..)) => DetectedHttpVersion::Http(Self::V1),
42+
Ok(Status::Partial) => match request.version {
43+
Some(..) => DetectedHttpVersion::Http(Self::V1),
44+
// If we haven't read enough bytes to consume the HTTP version,
45+
// we're not certain yet.
46+
None => DetectedHttpVersion::Unknown,
47+
},
48+
// We use a zero-length header array,
49+
// so this means we successfully parsed the method, uri and version.
50+
Err(httparse::Error::TooManyHeaders) => DetectedHttpVersion::Http(Self::V1),
51+
Err(..) => DetectedHttpVersion::NotHttp,
52+
}
53+
}
54+
}
55+
56+
/// Output of HTTP version detection on an prefix of an incoming stream.
57+
#[derive(PartialEq, Eq, Debug)]
58+
pub enum DetectedHttpVersion {
59+
/// We're certain that the stream is an HTTP connection.
60+
Http(HttpVersion),
61+
/// We're not sure yet.
62+
Unknown,
63+
/// We're certain that the stream is **not** an HTTP connection.
64+
NotHttp,
65+
}
66+
67+
impl DetectedHttpVersion {
68+
/// If the stream is known to be an HTTP connection,
69+
/// returns its version.
70+
///
71+
/// Otherwise, returns [`None`].
72+
pub fn into_version(self) -> Option<HttpVersion> {
73+
match self {
74+
Self::Http(version) => Some(version),
75+
Self::Unknown => None,
76+
Self::NotHttp => None,
77+
}
78+
}
3379

34-
if buffer.len() < Self::MINIMAL_HEADER_SIZE {
35-
None
36-
} else if buffer == &Self::H2_PREFACE[..Self::MINIMAL_HEADER_SIZE] {
37-
Some(Self::V2)
38-
} else {
39-
match httparse::Request::new(&mut empty_headers).parse(buffer) {
40-
Ok(..) | Err(httparse::Error::TooManyHeaders) => Some(Self::V1),
41-
_ => None,
42-
}
80+
/// Returns whether it's known whether the stream is an HTTP connection or not.
81+
pub fn is_known(&self) -> bool {
82+
match self {
83+
Self::Http(_) => true,
84+
Self::Unknown => false,
85+
Self::NotHttp => true,
4386
}
4487
}
4588
}
89+
90+
/// Attempts to detect HTTP version from the first bytes of a stream.
91+
///
92+
/// Keeps reading data until the timeout elapses or we're certain whether the stream is an HTTP
93+
/// connection or not.
94+
///
95+
/// # Notes
96+
///
97+
/// * The given `timeout` starts elapsing only after we complete the first read.
98+
/// * This function can read arbitrarily large amount of data from the stream. However,
99+
/// [`HttpVersion::detect`] should almost always be able to determine the stream type after
100+
/// reading no more than ~2kb (assuming **very** long request URI).
101+
/// * Consumed data is stored in [`RolledBackStream`]'s prefix, which will be dropped after the data
102+
/// is read again.
103+
pub async fn detect_http_version<IO>(
104+
mut stream: IO,
105+
timeout: Duration,
106+
) -> io::Result<(RolledBackStream<IO, BytesMut>, Option<HttpVersion>)>
107+
where
108+
IO: AsyncRead + Unpin,
109+
{
110+
let mut buf = BytesMut::with_capacity(1024);
111+
let mut detected = DetectedHttpVersion::Unknown;
112+
let mut timeout_at: Option<Instant> = None;
113+
114+
while detected.is_known().not() {
115+
let timeout_fut = OptionFuture::from(timeout_at.map(tokio::time::sleep_until));
116+
117+
let result = tokio::select! {
118+
Some(..) = timeout_fut => break,
119+
result = stream.read_buf(&mut buf) => result,
120+
};
121+
122+
let read_size = result?;
123+
if read_size == 0 {
124+
break;
125+
}
126+
127+
timeout_at = timeout_at.or_else(|| Some(Instant::now() + timeout));
128+
detected = HttpVersion::detect(buf.as_ref());
129+
}
130+
131+
Ok((RolledBackStream::new(stream, buf), detected.into_version()))
132+
}
133+
134+
#[cfg(test)]
135+
mod test {
136+
use rstest::rstest;
137+
138+
use super::{DetectedHttpVersion, HttpVersion};
139+
140+
#[rstest]
141+
#[case::known_bug(b"hello ther", DetectedHttpVersion::Unknown)]
142+
#[case::http2(b"PRI * HTTP/2.0", DetectedHttpVersion::Http(HttpVersion::V2))]
143+
#[case::http11_full(b"GET / HTTP/1.1\r\n\r\n", DetectedHttpVersion::Http(HttpVersion::V1))]
144+
#[case::http10_full(b"GET / HTTP/1.0\r\n\r\n", DetectedHttpVersion::Http(HttpVersion::V1))]
145+
#[case::custom_method(b"FOO / HTTP/1.1\r\n\r\n", DetectedHttpVersion::Http(HttpVersion::V1))]
146+
#[case::extra_spaces(b"GET / asd d HTTP/1.1\r\n\r\n", DetectedHttpVersion::NotHttp)]
147+
#[case::bad_version_1(b"GET / HTTP/a\r\n\r\n", DetectedHttpVersion::NotHttp)]
148+
#[case::bad_version_2(b"GET / HTTP/2\r\n\r\n", DetectedHttpVersion::NotHttp)]
149+
#[test]
150+
fn http_detect(#[case] input: &[u8], #[case] expected: DetectedHttpVersion) {
151+
let detected = HttpVersion::detect(input);
152+
assert_eq!(detected, expected,)
153+
}
154+
}

mirrord/agent/src/sniffer.rs

+1-5
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,7 @@ impl TcpPacketData {
8888
/// connection.
8989
#[tracing::instrument(level = Level::TRACE, ret)]
9090
fn treat_as_new_session(&self) -> bool {
91-
self.is_new_connection()
92-
|| matches!(
93-
HttpVersion::new(&self.bytes),
94-
Some(HttpVersion::V1 | HttpVersion::V2)
95-
)
91+
self.is_new_connection() || HttpVersion::detect(&self.bytes).into_version().is_some()
9692
}
9793
}
9894

mirrord/agent/src/steal/connections.rs

+10-11
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,15 @@ use tracing::Level;
2424

2525
use self::{filtered::DynamicBody, unfiltered::UnfilteredStealTask};
2626
use super::{
27-
http::DefaultReversibleStream,
2827
subscriptions::PortSubscription,
2928
tls::{self, error::StealTlsSetupError, StealTlsHandlerStore},
3029
};
3130
use crate::{
32-
http::HttpVersion, incoming::RedirectedConnection,
31+
http::{detect_http_version, HttpVersion},
32+
incoming::RedirectedConnection,
3333
metrics::STEAL_UNFILTERED_CONNECTION_SUBSCRIPTION,
34-
steal::connections::filtered::FilteredStealTask, util::ClientId,
34+
steal::connections::filtered::FilteredStealTask,
35+
util::ClientId,
3536
};
3637

3738
mod filtered;
@@ -513,11 +514,10 @@ impl ConnectionTask {
513514
};
514515

515516
let Some(tls_handler) = tls_handler else {
516-
let mut stream =
517-
DefaultReversibleStream::read_header(connection, Self::HTTP_DETECTION_TIMEOUT)
518-
.await?;
517+
let (mut stream, http_version) =
518+
detect_http_version(connection, Self::HTTP_DETECTION_TIMEOUT).await?;
519519

520-
let Some(http_version) = HttpVersion::new(stream.get_header()) else {
520+
let Some(http_version) = http_version else {
521521
tracing::trace!("No HTTP version detected, proxying the connection transparently");
522522

523523
let mut outgoing_io = TcpStream::connect(destination).await?;
@@ -590,11 +590,10 @@ impl ConnectionTask {
590590
return Ok(());
591591
}
592592
None => {
593-
let mut stream =
594-
DefaultReversibleStream::read_header(tls_stream, Self::HTTP_DETECTION_TIMEOUT)
595-
.await?;
593+
let (mut stream, http_version) =
594+
detect_http_version(tls_stream, Self::HTTP_DETECTION_TIMEOUT).await?;
596595

597-
let Some(http_version) = HttpVersion::new(stream.get_header()) else {
596+
let Some(http_version) = http_version else {
598597
tracing::trace!(
599598
"No HTTP version detected, proxying the connection transparently"
600599
);

mirrord/agent/src/steal/http.rs

-8
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,7 @@
11
//! Utils related to stealing with an HTTP filter.
22
3-
use crate::http::HttpVersion;
4-
53
mod filter;
64
mod response_fallback;
7-
mod reversible_stream;
85

96
pub(crate) use filter::HttpFilter;
107
pub(crate) use response_fallback::{HttpResponseFallback, ReceiverStreamBody};
11-
pub(crate) use reversible_stream::ReversibleStream;
12-
13-
/// Handy alias due to [`ReversibleStream`] being generic, avoiding value mismatches.
14-
pub(crate) type DefaultReversibleStream<IO> =
15-
ReversibleStream<{ HttpVersion::MINIMAL_HEADER_SIZE }, IO>;

0 commit comments

Comments
 (0)