1
+ use std:: { io, ops:: Not , time:: Duration } ;
2
+
3
+ use bytes:: BytesMut ;
4
+ use futures:: future:: OptionFuture ;
5
+ use httparse:: Status ;
6
+ use tokio:: {
7
+ io:: { AsyncRead , AsyncReadExt } ,
8
+ time:: Instant ,
9
+ } ;
10
+ use tracing:: Level ;
11
+
12
+ use crate :: util:: rolledback_stream:: RolledBackStream ;
13
+
1
14
/// Helper enum for representing HTTP/1.x and HTTP/2, which are handled very differently in some
2
15
/// parts of the code.
3
16
#[ derive( Debug , Clone , Copy , PartialEq , Eq , PartialOrd , Ord ) ]
@@ -12,34 +25,130 @@ pub enum HttpVersion {
12
25
impl HttpVersion {
13
26
/// Default start of an HTTP/2 request.
14
27
///
15
- /// Used in [`Self::new `] to check if the connection should be treated as HTTP/2.
16
- pub const H2_PREFACE : & ' static [ u8 ; 14 ] = b"PRI * HTTP/2.0" ;
28
+ /// Used in [`Self::detect `] to check if the connection should be treated as HTTP/2.
29
+ const H2_PREFACE : & ' static [ u8 ; 14 ] = b"PRI * HTTP/2.0" ;
17
30
18
- /// Controls the amount of data we read when trying to detect if the stream's first bytes
19
- /// contain an HTTP request. Used in [`Self::new`].
20
- ///
21
- /// **WARNING**: Can't be too small, otherwise we end up accepting things like "Foo " as valid
22
- /// HTTP requests.
23
- pub const MINIMAL_HEADER_SIZE : usize = 10 ;
31
+ /// Checks if the given `buffer` contains a prefix of a valid HTTP/1.x or HTTP/2 request.
32
+ # [ tracing :: instrument ( level = Level :: TRACE , ret ) ]
33
+ pub fn detect ( buffer : & [ u8 ] ) -> DetectedHttpVersion {
34
+ if buffer . starts_with ( Self :: H2_PREFACE ) {
35
+ return DetectedHttpVersion :: Http ( Self :: V2 ) ;
36
+ }
24
37
25
- /// Checks if `buffer` contains a prefix of a valid HTTP/1.x request, or if it could be an
26
- /// HTTP/2 request by comparing it with a slice of [`Self::H2_PREFACE`].
27
- ///
28
- /// The given `buffer` must contain at least [`Self::MINIMAL_HEADER_SIZE`] bytes, otherwise this
29
- /// function always returns [`None`].
30
- #[ tracing:: instrument( level = "trace" ) ]
31
- pub fn new ( buffer : & [ u8 ] ) -> Option < Self > {
32
38
let mut empty_headers = [ httparse:: EMPTY_HEADER ; 0 ] ;
39
+ let mut request = httparse:: Request :: new ( & mut empty_headers) ;
40
+ match request. parse ( buffer) {
41
+ Ok ( Status :: Complete ( ..) ) => DetectedHttpVersion :: Http ( Self :: V1 ) ,
42
+ Ok ( Status :: Partial ) => match request. version {
43
+ Some ( ..) => DetectedHttpVersion :: Http ( Self :: V1 ) ,
44
+ // If we haven't read enough bytes to consume the HTTP version,
45
+ // we're not certain yet.
46
+ None => DetectedHttpVersion :: Unknown ,
47
+ } ,
48
+ // We use a zero-length header array,
49
+ // so this means we successfully parsed the method, uri and version.
50
+ Err ( httparse:: Error :: TooManyHeaders ) => DetectedHttpVersion :: Http ( Self :: V1 ) ,
51
+ Err ( ..) => DetectedHttpVersion :: NotHttp ,
52
+ }
53
+ }
54
+ }
55
+
56
+ /// Output of HTTP version detection on an prefix of an incoming stream.
57
+ #[ derive( PartialEq , Eq , Debug ) ]
58
+ pub enum DetectedHttpVersion {
59
+ /// We're certain that the stream is an HTTP connection.
60
+ Http ( HttpVersion ) ,
61
+ /// We're not sure yet.
62
+ Unknown ,
63
+ /// We're certain that the stream is **not** an HTTP connection.
64
+ NotHttp ,
65
+ }
66
+
67
+ impl DetectedHttpVersion {
68
+ /// If the stream is known to be an HTTP connection,
69
+ /// returns its version.
70
+ ///
71
+ /// Otherwise, returns [`None`].
72
+ pub fn into_version ( self ) -> Option < HttpVersion > {
73
+ match self {
74
+ Self :: Http ( version) => Some ( version) ,
75
+ Self :: Unknown => None ,
76
+ Self :: NotHttp => None ,
77
+ }
78
+ }
33
79
34
- if buffer. len ( ) < Self :: MINIMAL_HEADER_SIZE {
35
- None
36
- } else if buffer == & Self :: H2_PREFACE [ ..Self :: MINIMAL_HEADER_SIZE ] {
37
- Some ( Self :: V2 )
38
- } else {
39
- match httparse:: Request :: new ( & mut empty_headers) . parse ( buffer) {
40
- Ok ( ..) | Err ( httparse:: Error :: TooManyHeaders ) => Some ( Self :: V1 ) ,
41
- _ => None ,
42
- }
80
+ /// Returns whether it's known whether the stream is an HTTP connection or not.
81
+ pub fn is_known ( & self ) -> bool {
82
+ match self {
83
+ Self :: Http ( _) => true ,
84
+ Self :: Unknown => false ,
85
+ Self :: NotHttp => true ,
43
86
}
44
87
}
45
88
}
89
+
90
+ /// Attempts to detect HTTP version from the first bytes of a stream.
91
+ ///
92
+ /// Keeps reading data until the timeout elapses or we're certain whether the stream is an HTTP
93
+ /// connection or not.
94
+ ///
95
+ /// # Notes
96
+ ///
97
+ /// * The given `timeout` starts elapsing only after we complete the first read.
98
+ /// * This function can read arbitrarily large amount of data from the stream. However,
99
+ /// [`HttpVersion::detect`] should almost always be able to determine the stream type after
100
+ /// reading no more than ~2kb (assuming **very** long request URI).
101
+ /// * Consumed data is stored in [`RolledBackStream`]'s prefix, which will be dropped after the data
102
+ /// is read again.
103
+ pub async fn detect_http_version < IO > (
104
+ mut stream : IO ,
105
+ timeout : Duration ,
106
+ ) -> io:: Result < ( RolledBackStream < IO , BytesMut > , Option < HttpVersion > ) >
107
+ where
108
+ IO : AsyncRead + Unpin ,
109
+ {
110
+ let mut buf = BytesMut :: with_capacity ( 1024 ) ;
111
+ let mut detected = DetectedHttpVersion :: Unknown ;
112
+ let mut timeout_at: Option < Instant > = None ;
113
+
114
+ while detected. is_known ( ) . not ( ) {
115
+ let timeout_fut = OptionFuture :: from ( timeout_at. map ( tokio:: time:: sleep_until) ) ;
116
+
117
+ let result = tokio:: select! {
118
+ Some ( ..) = timeout_fut => break ,
119
+ result = stream. read_buf( & mut buf) => result,
120
+ } ;
121
+
122
+ let read_size = result?;
123
+ if read_size == 0 {
124
+ break ;
125
+ }
126
+
127
+ timeout_at = timeout_at. or_else ( || Some ( Instant :: now ( ) + timeout) ) ;
128
+ detected = HttpVersion :: detect ( buf. as_ref ( ) ) ;
129
+ }
130
+
131
+ Ok ( ( RolledBackStream :: new ( stream, buf) , detected. into_version ( ) ) )
132
+ }
133
+
134
+ #[ cfg( test) ]
135
+ mod test {
136
+ use rstest:: rstest;
137
+
138
+ use super :: { DetectedHttpVersion , HttpVersion } ;
139
+
140
+ #[ rstest]
141
+ #[ case:: known_bug( b"hello ther" , DetectedHttpVersion :: Unknown ) ]
142
+ #[ case:: http2( b"PRI * HTTP/2.0" , DetectedHttpVersion :: Http ( HttpVersion :: V2 ) ) ]
143
+ #[ case:: http11_full( b"GET / HTTP/1.1\r \n \r \n " , DetectedHttpVersion :: Http ( HttpVersion :: V1 ) ) ]
144
+ #[ case:: http10_full( b"GET / HTTP/1.0\r \n \r \n " , DetectedHttpVersion :: Http ( HttpVersion :: V1 ) ) ]
145
+ #[ case:: custom_method( b"FOO / HTTP/1.1\r \n \r \n " , DetectedHttpVersion :: Http ( HttpVersion :: V1 ) ) ]
146
+ #[ case:: extra_spaces( b"GET / asd d HTTP/1.1\r \n \r \n " , DetectedHttpVersion :: NotHttp ) ]
147
+ #[ case:: bad_version_1( b"GET / HTTP/a\r \n \r \n " , DetectedHttpVersion :: NotHttp ) ]
148
+ #[ case:: bad_version_2( b"GET / HTTP/2\r \n \r \n " , DetectedHttpVersion :: NotHttp ) ]
149
+ #[ test]
150
+ fn http_detect ( #[ case] input : & [ u8 ] , #[ case] expected : DetectedHttpVersion ) {
151
+ let detected = HttpVersion :: detect ( input) ;
152
+ assert_eq ! ( detected, expected, )
153
+ }
154
+ }
0 commit comments