Skip to content

Better HTTP detection in the agent #3307

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/3296.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improved HTTP detection logic in the mirrord agent.
157 changes: 133 additions & 24 deletions mirrord/agent/src/http.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
use std::{io, ops::Not, time::Duration};

use bytes::BytesMut;
use futures::future::OptionFuture;
use httparse::Status;
use tokio::{
io::{AsyncRead, AsyncReadExt},
time::Instant,
};
use tracing::Level;

use crate::util::rolledback_stream::RolledBackStream;

/// Helper enum for representing HTTP/1.x and HTTP/2, which are handled very differently in some
/// parts of the code.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
Expand All @@ -12,34 +25,130 @@ pub enum HttpVersion {
impl HttpVersion {
/// Default start of an HTTP/2 request.
///
/// Used in [`Self::new`] to check if the connection should be treated as HTTP/2.
pub const H2_PREFACE: &'static [u8; 14] = b"PRI * HTTP/2.0";
/// Used in [`Self::detect`] to check if the connection should be treated as HTTP/2.
const H2_PREFACE: &'static [u8; 14] = b"PRI * HTTP/2.0";

/// Controls the amount of data we read when trying to detect if the stream's first bytes
/// contain an HTTP request. Used in [`Self::new`].
///
/// **WARNING**: Can't be too small, otherwise we end up accepting things like "Foo " as valid
/// HTTP requests.
pub const MINIMAL_HEADER_SIZE: usize = 10;
/// Checks if the given `buffer` contains a prefix of a valid HTTP/1.x or HTTP/2 request.
#[tracing::instrument(level = Level::TRACE, ret)]
pub fn detect(buffer: &[u8]) -> DetectedHttpVersion {
if buffer.starts_with(Self::H2_PREFACE) {
return DetectedHttpVersion::Http(Self::V2);
}

/// Checks if `buffer` contains a prefix of a valid HTTP/1.x request, or if it could be an
/// HTTP/2 request by comparing it with a slice of [`Self::H2_PREFACE`].
///
/// The given `buffer` must contain at least [`Self::MINIMAL_HEADER_SIZE`] bytes, otherwise this
/// function always returns [`None`].
#[tracing::instrument(level = "trace")]
pub fn new(buffer: &[u8]) -> Option<Self> {
let mut empty_headers = [httparse::EMPTY_HEADER; 0];
let mut request = httparse::Request::new(&mut empty_headers);
match request.parse(buffer) {
Ok(Status::Complete(..)) => DetectedHttpVersion::Http(Self::V1),
Ok(Status::Partial) => match request.version {
Some(..) => DetectedHttpVersion::Http(Self::V1),
// If we haven't read enough bytes to consume the HTTP version,
// we're not certain yet.
None => DetectedHttpVersion::Unknown,
},
// We use a zero-length header array,
// so this means we successfully parsed the method, uri and version.
Err(httparse::Error::TooManyHeaders) => DetectedHttpVersion::Http(Self::V1),
Err(..) => DetectedHttpVersion::NotHttp,
}
}
}

/// Output of HTTP version detection on a prefix of an incoming stream.
///
/// Detection is tri-state: besides a definite "yes" and a definite "no", the prefix seen so far
/// may simply be too short to decide, in which case more data has to be read.
#[derive(PartialEq, Eq, Debug)]
pub enum DetectedHttpVersion {
/// We're certain that the stream is an HTTP connection. Carries the detected version.
Http(HttpVersion),
/// We're not sure yet. More bytes are needed to reach a verdict.
Unknown,
/// We're certain that the stream is **not** an HTTP connection.
NotHttp,
}

impl DetectedHttpVersion {
/// Converts the detection result into an optional [`HttpVersion`].
///
/// Only [`Self::Http`] carries a version and yields [`Some`].
/// The undecided ([`Self::Unknown`]) and negative ([`Self::NotHttp`]) outcomes both yield
/// [`None`].
pub fn into_version(self) -> Option<HttpVersion> {
    match self {
        Self::Http(http_version) => Some(http_version),
        // Neither of these states has a version to report.
        Self::Unknown | Self::NotHttp => None,
    }
}

if buffer.len() < Self::MINIMAL_HEADER_SIZE {
None
} else if buffer == &Self::H2_PREFACE[..Self::MINIMAL_HEADER_SIZE] {
Some(Self::V2)
} else {
match httparse::Request::new(&mut empty_headers).parse(buffer) {
Ok(..) | Err(httparse::Error::TooManyHeaders) => Some(Self::V1),
_ => None,
}
/// Returns `true` once detection has reached a final verdict, either way.
///
/// Both [`Self::Http`] and [`Self::NotHttp`] are final.
/// Only [`Self::Unknown`] means that more data is needed.
pub fn is_known(&self) -> bool {
    match self {
        // Still undecided.
        Self::Unknown => false,
        // Decided, positively or negatively.
        Self::Http(_) | Self::NotHttp => true,
    }
}
}

/// Reads the beginning of `stream` and tries to determine which HTTP version it speaks, if any.
///
/// Data is read in a loop and the whole accumulated prefix is passed to [`HttpVersion::detect`]
/// after every successful read. The loop ends when any of the following happens:
///
/// * detection reaches a final verdict (HTTP or not HTTP),
/// * the stream reaches EOF (a read returns 0 bytes),
/// * the detection deadline elapses.
///
/// On success, returns the stream wrapped in a [`RolledBackStream`] that carries all consumed
/// bytes as its prefix, together with the detected version. [`None`] means that the stream is not
/// HTTP or that no verdict was reached.
///
/// # Errors
///
/// Any IO error returned when reading from `stream` is propagated to the caller.
///
/// # Notes
///
/// * The `timeout` clock starts only after the first read completes. Waiting for the very first
///   bytes is not bounded by this function.
/// * There is no hard limit on the amount of data read from the stream. However,
///   [`HttpVersion::detect`] should almost always be able to determine the stream type after
///   reading no more than ~2kb (assuming a **very** long request URI).
/// * The consumed data is stored in the [`RolledBackStream`]'s prefix, which will be dropped after
///   the data is read again.
pub async fn detect_http_version<IO>(
    mut stream: IO,
    timeout: Duration,
) -> io::Result<(RolledBackStream<IO, BytesMut>, Option<HttpVersion>)>
where
    IO: AsyncRead + Unpin,
{
    let mut prefix = BytesMut::with_capacity(1024);
    let mut verdict = DetectedHttpVersion::Unknown;
    // Stays unset until the first read completes.
    let mut deadline: Option<Instant> = None;

    while verdict.is_known().not() {
        // With no deadline this future resolves to `None`, which does not match the
        // `Some(..)` pattern below, so the timeout branch stays disabled.
        let deadline_elapsed = OptionFuture::from(deadline.map(tokio::time::sleep_until));

        let read_result = tokio::select! {
            Some(..) = deadline_elapsed => break,
            read_result = stream.read_buf(&mut prefix) => read_result,
        };

        if read_result? == 0 {
            // EOF: no more data will ever arrive.
            break;
        }

        // Arm the deadline once, right after the first successful read.
        if deadline.is_none() {
            deadline = Some(Instant::now() + timeout);
        }

        verdict = HttpVersion::detect(&prefix);
    }

    let version = verdict.into_version();
    Ok((RolledBackStream::new(stream, prefix), version))
}

#[cfg(test)]
mod test {
    use rstest::rstest;

    use super::{DetectedHttpVersion, HttpVersion};

    // Table-driven check of `HttpVersion::detect` against fixed byte prefixes.
    //
    // `known_bug`: a short non-HTTP prefix is reported as `Unknown` rather than `NotHttp`.
    #[rstest]
    #[case::known_bug(b"hello ther", DetectedHttpVersion::Unknown)]
    #[case::http2(b"PRI * HTTP/2.0", DetectedHttpVersion::Http(HttpVersion::V2))]
    #[case::http11_full(b"GET / HTTP/1.1\r\n\r\n", DetectedHttpVersion::Http(HttpVersion::V1))]
    #[case::http10_full(b"GET / HTTP/1.0\r\n\r\n", DetectedHttpVersion::Http(HttpVersion::V1))]
    #[case::custom_method(b"FOO / HTTP/1.1\r\n\r\n", DetectedHttpVersion::Http(HttpVersion::V1))]
    #[case::extra_spaces(b"GET / asd d HTTP/1.1\r\n\r\n", DetectedHttpVersion::NotHttp)]
    #[case::bad_version_1(b"GET / HTTP/a\r\n\r\n", DetectedHttpVersion::NotHttp)]
    #[case::bad_version_2(b"GET / HTTP/2\r\n\r\n", DetectedHttpVersion::NotHttp)]
    #[test]
    fn http_detect(#[case] input: &[u8], #[case] expected: DetectedHttpVersion) {
        assert_eq!(HttpVersion::detect(input), expected);
    }
}
6 changes: 1 addition & 5 deletions mirrord/agent/src/sniffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,7 @@ impl TcpPacketData {
/// connection.
#[tracing::instrument(level = Level::TRACE, ret)]
fn treat_as_new_session(&self) -> bool {
self.is_new_connection()
|| matches!(
HttpVersion::new(&self.bytes),
Some(HttpVersion::V1 | HttpVersion::V2)
)
self.is_new_connection() || HttpVersion::detect(&self.bytes).into_version().is_some()
}
}

Expand Down
21 changes: 10 additions & 11 deletions mirrord/agent/src/steal/connections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,15 @@ use tracing::Level;

use self::{filtered::DynamicBody, unfiltered::UnfilteredStealTask};
use super::{
http::DefaultReversibleStream,
subscriptions::PortSubscription,
tls::{self, error::StealTlsSetupError, StealTlsHandlerStore},
};
use crate::{
http::HttpVersion, incoming::RedirectedConnection,
http::{detect_http_version, HttpVersion},
incoming::RedirectedConnection,
metrics::STEAL_UNFILTERED_CONNECTION_SUBSCRIPTION,
steal::connections::filtered::FilteredStealTask, util::ClientId,
steal::connections::filtered::FilteredStealTask,
util::ClientId,
};

mod filtered;
Expand Down Expand Up @@ -513,11 +514,10 @@ impl ConnectionTask {
};

let Some(tls_handler) = tls_handler else {
let mut stream =
DefaultReversibleStream::read_header(connection, Self::HTTP_DETECTION_TIMEOUT)
.await?;
let (mut stream, http_version) =
detect_http_version(connection, Self::HTTP_DETECTION_TIMEOUT).await?;

let Some(http_version) = HttpVersion::new(stream.get_header()) else {
let Some(http_version) = http_version else {
tracing::trace!("No HTTP version detected, proxying the connection transparently");

let mut outgoing_io = TcpStream::connect(destination).await?;
Expand Down Expand Up @@ -590,11 +590,10 @@ impl ConnectionTask {
return Ok(());
}
None => {
let mut stream =
DefaultReversibleStream::read_header(tls_stream, Self::HTTP_DETECTION_TIMEOUT)
.await?;
let (mut stream, http_version) =
detect_http_version(tls_stream, Self::HTTP_DETECTION_TIMEOUT).await?;

let Some(http_version) = HttpVersion::new(stream.get_header()) else {
let Some(http_version) = http_version else {
tracing::trace!(
"No HTTP version detected, proxying the connection transparently"
);
Expand Down
8 changes: 0 additions & 8 deletions mirrord/agent/src/steal/http.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,7 @@
//! Utils related to stealing with an HTTP filter.

use crate::http::HttpVersion;

mod filter;
mod response_fallback;
mod reversible_stream;

pub(crate) use filter::HttpFilter;
pub(crate) use response_fallback::{HttpResponseFallback, ReceiverStreamBody};
pub(crate) use reversible_stream::ReversibleStream;

/// Handy alias due to [`ReversibleStream`] being generic, avoiding value mismatches.
pub(crate) type DefaultReversibleStream<IO> =
ReversibleStream<{ HttpVersion::MINIMAL_HEADER_SIZE }, IO>;
Loading
Loading