diff --git a/header.go b/header.go index 209c2cc..791e94e 100644 --- a/header.go +++ b/header.go @@ -42,6 +42,7 @@ type Header struct { SourceAddr net.Addr DestinationAddr net.Addr rawTLVs []byte + length int } // HeaderProxyFromAddrs creates a new PROXY header from a source and a diff --git a/protocol.go b/protocol.go index 7eda3d6..549b538 100644 --- a/protocol.go +++ b/protocol.go @@ -2,6 +2,7 @@ package proxyproto import ( "bufio" + "bytes" "errors" "fmt" "io" @@ -51,7 +52,7 @@ type Conn struct { once sync.Once readErr error conn net.Conn - bufReader *bufio.Reader + reader io.Reader header *Header ProxyHeaderPolicy Policy Validate Validator @@ -150,14 +151,8 @@ func (p *Listener) Addr() net.Addr { // NewConn is used to wrap a net.Conn that may be speaking // the proxy protocol into a proxyproto.Conn func NewConn(conn net.Conn, opts ...func(*Conn)) *Conn { - // For v1 the header length is at most 108 bytes. - // For v2 the header length is at most 52 bytes plus the length of the TLVs. - // We use 256 bytes to be safe. - const bufSize = 256 - pConn := &Conn{ - bufReader: bufio.NewReaderSize(conn, bufSize), - conn: conn, + conn: conn, } for _, opt := range opts { @@ -178,7 +173,7 @@ func (p *Conn) Read(b []byte) (int, error) { return 0, p.readErr } - return p.bufReader.Read(b) + return p.reader.Read(b) } // Write wraps original conn.Write @@ -294,7 +289,21 @@ func (p *Conn) readHeader() error { } } - header, err := Read(p.bufReader) + // For v1 the header length is at most 108 bytes. + // For v2 the header length is at most 52 bytes plus the length of the TLVs. + // We use 256 bytes to be safe. + const bufSize = 256 + + bb := bytes.NewBuffer(make([]byte, 0, bufSize)) + tr := io.TeeReader(p.conn, bb) + br := bufio.NewReaderSize(tr, bufSize) + + header, err := Read(br) + + if err == nil { + _, err = io.CopyN(io.Discard, bb, int64(header.length)) + } + p.reader = io.MultiReader(bb, p.conn) // If the connection's readHeaderTimeout is more than 0, undo the change to the // deadline that we made above. Because we retain the readDeadline as part of our @@ -360,5 +369,5 @@ func (p *Conn) WriteTo(w io.Writer) (int64, error) { if p.readErr != nil { return 0, p.readErr } - return p.bufReader.WriteTo(w) + return io.Copy(w, p.reader) } diff --git a/v1.go b/v1.go index 0d34ba5..3c948eb 100644 --- a/v1.go +++ b/v1.go @@ -125,6 +125,7 @@ func parseVersion1(reader *bufio.Reader) (*Header, error) { // Command doesn't exist in v1 but set it for other parts of this library // to rely on it for determining connection details. header := initVersion1() + header.length = len(buf) // Transport protocol has been processed already. header.TransportProtocol = transportProtocol diff --git a/v2.go b/v2.go index 74bf3f0..2dde097 100644 --- a/v2.go +++ b/v2.go @@ -100,6 +100,7 @@ func parseVersion2(reader *bufio.Reader) (header *Header, err error) { if !header.validateLength(length) { return nil, ErrInvalidLength } + header.length = 16 + int(length) // Return early if the length is zero, which means that // there's no address information and TLVs present for UNSPEC.