1
1
use futures:: future:: FutureExt ;
2
2
use futures:: StreamExt ;
3
3
use rand:: prelude:: * ;
4
+ use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
4
5
use std:: sync:: Arc ;
5
6
use thiserror:: Error ;
6
7
8
+ use crate :: tcp_stream:: CustomTcpStream ;
7
9
use crate :: ConnectToEntry ;
8
10
9
11
#[ derive( Debug , Clone ) ]
@@ -128,6 +130,7 @@ impl ClientBuilder {
128
130
rng : rand:: rngs:: StdRng :: from_entropy ( ) ,
129
131
} ,
130
132
client : None ,
133
+ read_bytes_counter : Arc :: new ( AtomicUsize :: new ( 0 ) ) ,
131
134
timeout : self . timeout ,
132
135
http_version : self . http_version ,
133
136
redirect_limit : self . redirect_limit ,
@@ -195,6 +198,7 @@ pub struct Client {
195
198
body : Option < & ' static [ u8 ] > ,
196
199
dns : DNS ,
197
200
client : Option < hyper:: client:: conn:: SendRequest < hyper:: Body > > ,
201
+ read_bytes_counter : Arc < AtomicUsize > ,
198
202
timeout : Option < std:: time:: Duration > ,
199
203
redirect_limit : usize ,
200
204
disable_keepalive : bool ,
@@ -211,6 +215,7 @@ impl Client {
211
215
} else {
212
216
let stream = tokio:: net:: TcpStream :: connect ( addr) . await ?;
213
217
stream. set_nodelay ( true ) ?;
218
+ let stream = CustomTcpStream :: new ( stream, self . read_bytes_counter . clone ( ) ) ;
214
219
// stream.set_keepalive(std::time::Duration::from_secs(1).into())?;
215
220
let ( send, conn) = hyper:: client:: conn:: handshake ( stream) . await ?;
216
221
tokio:: spawn ( conn) ;
@@ -225,6 +230,7 @@ impl Client {
225
230
) -> Result < hyper:: client:: conn:: SendRequest < hyper:: Body > , ClientError > {
226
231
let stream = tokio:: net:: TcpStream :: connect ( addr) . await ?;
227
232
stream. set_nodelay ( true ) ?;
233
+ let stream = CustomTcpStream :: new ( stream, self . read_bytes_counter . clone ( ) ) ;
228
234
229
235
let connector = if self . insecure {
230
236
native_tls:: TlsConnector :: builder ( )
@@ -251,6 +257,7 @@ impl Client {
251
257
) -> Result < hyper:: client:: conn:: SendRequest < hyper:: Body > , ClientError > {
252
258
let stream = tokio:: net:: TcpStream :: connect ( addr) . await ?;
253
259
stream. set_nodelay ( true ) ?;
260
+ let stream = CustomTcpStream :: new ( stream, self . read_bytes_counter . clone ( ) ) ;
254
261
255
262
let mut root_cert_store = rustls:: RootCertStore :: empty ( ) ;
256
263
for cert in rustls_native_certs:: load_native_certs ( ) ? {
@@ -331,19 +338,17 @@ impl Client {
331
338
connection_time = Some ( ConnectionTime { dns_lookup, dialup } ) ;
332
339
}
333
340
let request = self . request ( & self . url ) ?;
341
+ self . read_bytes_counter . store ( 0 , Ordering :: Relaxed ) ;
334
342
match send_request. send_request ( request) . await {
335
343
Ok ( res) => {
336
344
let ( parts, mut stream) = res. into_parts ( ) ;
337
345
let mut status = parts. status ;
338
346
339
- let mut len_sum = 0 ;
340
- while let Some ( chunk) = stream. next ( ) . await {
341
- len_sum += chunk?. len ( ) ;
342
- }
347
+ while stream. next ( ) . await . is_some ( ) { }
343
348
344
349
if self . redirect_limit != 0 {
345
350
if let Some ( location) = parts. headers . get ( "Location" ) {
346
- let ( send_request_redirect, new_status, len ) = self
351
+ let ( send_request_redirect, new_status) = self
347
352
. redirect (
348
353
send_request,
349
354
& self . url . clone ( ) ,
@@ -354,7 +359,6 @@ impl Client {
354
359
355
360
send_request = send_request_redirect;
356
361
status = new_status;
357
- len_sum = len;
358
362
}
359
363
}
360
364
@@ -364,7 +368,7 @@ impl Client {
364
368
start,
365
369
end,
366
370
status,
367
- len_bytes : len_sum ,
371
+ len_bytes : self . read_bytes_counter . load ( Ordering :: Relaxed ) ,
368
372
connection_time,
369
373
} ;
370
374
@@ -404,7 +408,6 @@ impl Client {
404
408
(
405
409
hyper:: client:: conn:: SendRequest < hyper:: Body > ,
406
410
http:: StatusCode ,
407
- usize ,
408
411
) ,
409
412
ClientError ,
410
413
> ,
@@ -451,28 +454,25 @@ impl Client {
451
454
) ?,
452
455
) ;
453
456
}
457
+ self . read_bytes_counter . store ( 0 , Ordering :: Relaxed ) ;
454
458
let res = send_request. send_request ( request) . await ?;
455
459
let ( parts, mut stream) = res. into_parts ( ) ;
456
460
let mut status = parts. status ;
457
461
458
- let mut len_sum = 0 ;
459
- while let Some ( chunk) = stream. next ( ) . await {
460
- len_sum += chunk?. len ( ) ;
461
- }
462
+ while stream. next ( ) . await . is_some ( ) { }
462
463
463
464
if let Some ( location) = parts. headers . get ( "Location" ) {
464
- let ( send_request_redirect, new_status, len ) = self
465
+ let ( send_request_redirect, new_status) = self
465
466
. redirect ( send_request, & url, location, limit - 1 )
466
467
. await ?;
467
468
send_request = send_request_redirect;
468
469
status = new_status;
469
- len_sum = len;
470
470
}
471
471
472
472
if let Some ( send_request_base) = send_request_base {
473
- Ok ( ( send_request_base, status, len_sum ) )
473
+ Ok ( ( send_request_base, status) )
474
474
} else {
475
- Ok ( ( send_request, status, len_sum ) )
475
+ Ok ( ( send_request, status) )
476
476
}
477
477
}
478
478
. boxed ( )
@@ -546,7 +546,6 @@ pub async fn work(
546
546
n_tasks : usize ,
547
547
n_workers : usize ,
548
548
) {
549
- use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
550
549
let counter = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
551
550
552
551
let futures = ( 0 ..n_workers)
0 commit comments