Skip to content

Commit

Permalink
update WriteTo to make a single .Write call
Browse files Browse the repository at this point in the history
  • Loading branch information
mastercactapus committed Sep 22, 2020
1 parent 33bfbdd commit ae65f2c
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 29 deletions.
35 changes: 35 additions & 0 deletions buffer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package proxyprotocol

import "io"

type buffer struct {
data []byte
size int
}

func newBuffer(pos, cap int) *buffer {
return &buffer{
size: pos,
data: make([]byte, pos, cap),
}
}

func (b *buffer) Seek(pos int) {
b.data = b.data[:pos]
if pos > b.size {
b.size = pos
}
}
func (b *buffer) Write(p []byte) (int, error) {
b.data = append(b.data, p...)
l := len(b.data)
if l > b.size {
b.size = l
}
return len(p), nil
}
func (b *buffer) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write(b.data[:b.size])
return int64(n), err
}
func (b *buffer) Len() int { return b.size }
19 changes: 2 additions & 17 deletions cmd/proxy-get/main.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package main

import (
"bytes"
"flag"
"fmt"
"io"
Expand Down Expand Up @@ -90,14 +89,7 @@ func main() {
DestPort: d.Port,
}

buf := new(bytes.Buffer)
_, err = hdr.WriteTo(buf)
if err != nil {
c.Close()
return nil, fmt.Errorf("write v1 header: %w", err)
}

_, err = c.Write(buf.Bytes())
_, err = hdr.WriteTo(c)
if err != nil {
c.Close()
return nil, fmt.Errorf("write v1 header: %w", err)
Expand All @@ -123,14 +115,7 @@ func main() {
hdr.Command = proxyprotocol.CmdLocal
}

buf := new(bytes.Buffer)
_, err = hdr.WriteTo(buf)
if err != nil {
c.Close()
return nil, fmt.Errorf("write v2 header: %w", err)
}

_, err = c.Write(buf.Bytes())
_, err = hdr.WriteTo(c)
if err != nil {
c.Close()
return nil, fmt.Errorf("write v2 header: %w", err)
Expand Down
23 changes: 11 additions & 12 deletions headerv2.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ func (h HeaderV2) WriteTo(w io.Writer) (int64, error) {
return sendEmpty()
}

addr := make([]byte, 216)
buf := newBuffer(16, 232)

setAddr := func(srcIP, dstIP net.IP, srcPort, dstPort int) (fam byte) {
src := srcIP.To4()
Expand All @@ -210,12 +210,12 @@ func (h HeaderV2) WriteTo(w io.Writer) (int64, error) {
if src == nil || dst == nil {
return 0 // UNSPEC
}
buf := bytes.NewBuffer(addr[:0])

buf.Write(src)
buf.Write(dst)
binary.Write(buf, binary.BigEndian, uint16(srcPort))
binary.Write(buf, binary.BigEndian, uint16(dstPort))
addr = buf.Bytes()

return fam
}

Expand Down Expand Up @@ -257,20 +257,19 @@ func (h HeaderV2) WriteTo(w io.Writer) (int64, error) {
default:
return sendEmpty()
}
copy(addr, src.Name)
copy(addr[108:], dst.Name)
buf.Write([]byte(src.Name))
buf.Seek(108 + 16)
buf.Write([]byte(dst.Name))
buf.Seek(232)
}

rawHdr.Len = uint16(len(addr))
rawHdr.Len = uint16(buf.Len() - 16)

err := binary.Write(w, binary.BigEndian, rawHdr)
buf.Seek(0)
err := binary.Write(buf, binary.BigEndian, rawHdr)
if err != nil {
return 0, err
}

n, err := w.Write(addr)
if err != nil {
return int64(16 + n), err
}
return int64(16 + n), err
return buf.WriteTo(w)
}

0 comments on commit ae65f2c

Please sign in to comment.