Fix: forever read + Feat: listen on unix socket
This commit is contained in:
parent
8ba93372fc
commit
0dcbaf3b86
50
builder.go
Normal file
50
builder.go
Normal file
@ -0,0 +1,50 @@
|
||||
package putxt
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"git.tcp.direct/kayos/common/pool"
|
||||
"github.com/yunginnanet/Rate5"
|
||||
)
|
||||
|
||||
type TermDumpster struct {
|
||||
gzip bool
|
||||
maxSize int64
|
||||
timeout time.Duration
|
||||
handler Handler
|
||||
log Logger
|
||||
Pool pool.BufferFactory
|
||||
*rate5.Limiter
|
||||
}
|
||||
|
||||
func NewTermDumpster(handler Handler) *TermDumpster {
|
||||
td := &TermDumpster{
|
||||
maxSize: 3 << 20,
|
||||
timeout: 5 * time.Second,
|
||||
Limiter: rate5.NewStrictLimiter(60, 5),
|
||||
handler: handler,
|
||||
log: dummyLogger{},
|
||||
Pool: pool.NewBufferFactory(),
|
||||
}
|
||||
return td
|
||||
}
|
||||
|
||||
func (td *TermDumpster) WithGzip() *TermDumpster {
|
||||
td.gzip = true
|
||||
return td
|
||||
}
|
||||
|
||||
func (td *TermDumpster) WithMaxSize(size int64) *TermDumpster {
|
||||
td.maxSize = size
|
||||
return td
|
||||
}
|
||||
|
||||
func (td *TermDumpster) WithTimeout(timeout time.Duration) *TermDumpster {
|
||||
td.timeout = timeout
|
||||
return td
|
||||
}
|
||||
|
||||
func (td *TermDumpster) WithLogger(logger Logger) *TermDumpster {
|
||||
td.log = logger
|
||||
return td
|
||||
}
|
@ -1,38 +0,0 @@
|
||||
package main
|
||||
|
||||
// meant to act as a simple example
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.tcp.direct/kayos/common/squish"
|
||||
|
||||
termbin "git.tcp.direct/kayos/putxt"
|
||||
)
|
||||
|
||||
type handler struct{}
|
||||
|
||||
func (h *handler) Ingest(data []byte) ([]byte, error) {
|
||||
var err error
|
||||
data, err = squish.Gunzip(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.ReplaceAll(string(data), "\n", "") == "ping" {
|
||||
println("got ping, sending pong...")
|
||||
return []byte("pong"), nil
|
||||
}
|
||||
println(string(data))
|
||||
return []byte("invalid request"), errors.New("invalid data")
|
||||
}
|
||||
|
||||
func main() {
|
||||
td := termbin.NewTermDumpster(&handler{}).WithGzip().WithMaxSize(3 << 20).WithTimeout(5 * time.Second)
|
||||
err := td.Listen("127.0.0.1", "8888")
|
||||
if err != nil {
|
||||
println(err.Error())
|
||||
return
|
||||
}
|
||||
}
|
5
go.mod
5
go.mod
@ -3,14 +3,15 @@ module git.tcp.direct/kayos/putxt
|
||||
go 1.18
|
||||
|
||||
require (
|
||||
git.tcp.direct/kayos/common v0.7.0
|
||||
git.tcp.direct/kayos/common v0.7.5
|
||||
github.com/yunginnanet/Rate5 v1.1.0
|
||||
golang.org/x/tools v0.1.12
|
||||
inet.af/netaddr v0.0.0-20220617031823-097006376321
|
||||
inet.af/netaddr v0.0.0-20220811202034-502d2d690317
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
|
||||
go4.org/intern v0.0.0-20211027215823-ae77deb06f29 // indirect
|
||||
go4.org/unsafe/assume-no-moving-gc v0.0.0-20220617031537-928513b29760 // indirect
|
||||
nullprogram.com/x/rng v1.1.0 // indirect
|
||||
)
|
||||
|
10
go.sum
10
go.sum
@ -1,5 +1,5 @@
|
||||
git.tcp.direct/kayos/common v0.7.0 h1:KZDwoCzUiwQaYSWESr080N8wUVyLD27QYgzXgc7LiAQ=
|
||||
git.tcp.direct/kayos/common v0.7.0/go.mod h1:7tMZBVNPLFSZk+JXTA6pgXWpf/XHqYRfT7Q3OziI++Y=
|
||||
git.tcp.direct/kayos/common v0.7.5 h1:a95oIv3pzRwzYaINqFASnXqXOWVWupIVWHcOtTVUOHU=
|
||||
git.tcp.direct/kayos/common v0.7.5/go.mod h1:jVbdX9prBrx9e3aTsNpu643brGVgpLvysl40/F5U2cE=
|
||||
github.com/dvyukov/go-fuzz v0.0.0-20210103155950-6a8e9d1f2415/go.mod h1:11Gm+ccJnvAhCNLlf5+cS9KjtbaD5I5zaZpFMsTHWTw=
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
|
||||
@ -34,5 +34,7 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
inet.af/netaddr v0.0.0-20220617031823-097006376321 h1:B4dC8ySKTQXasnjDTMsoCMf1sQG4WsMej0WXaHxunmU=
|
||||
inet.af/netaddr v0.0.0-20220617031823-097006376321/go.mod h1:OIezDfdzOgFhuw4HuWapWq2e9l0H9tK4F1j+ETRtF3k=
|
||||
inet.af/netaddr v0.0.0-20220811202034-502d2d690317 h1:U2fwK6P2EqmopP/hFLTOAjWTki0qgd4GMJn5X8wOleU=
|
||||
inet.af/netaddr v0.0.0-20220811202034-502d2d690317/go.mod h1:OIezDfdzOgFhuw4HuWapWq2e9l0H9tK4F1j+ETRtF3k=
|
||||
nullprogram.com/x/rng v1.1.0 h1:SMU7DHaQSWtKJNTpNFIFt8Wd/KSmOuSDPXrMFp/UMro=
|
||||
nullprogram.com/x/rng v1.1.0/go.mod h1:glGw6V87vyfawxCzqOABL3WfL95G65az9Z2JZCylCkg=
|
||||
|
20
logger.go
Normal file
20
logger.go
Normal file
@ -0,0 +1,20 @@
|
||||
package putxt
|
||||
|
||||
import "fmt"
|
||||
|
||||
const (
|
||||
MessageRatelimited = "RATELIMIT_REACHED"
|
||||
MessageSizeLimited = "MAX_SIZE_EXCEEDED"
|
||||
MessageBinaryData = "BINARY_DATA_REJECTED"
|
||||
MessageInternalError = "INTERNAL_ERROR"
|
||||
)
|
||||
|
||||
type Logger interface {
|
||||
Printf(format string, v ...interface{})
|
||||
}
|
||||
|
||||
type dummyLogger struct{}
|
||||
|
||||
func (dummyLogger) Printf(format string, v ...interface{}) {
|
||||
_, _ = fmt.Printf(format, v...)
|
||||
}
|
184
main.go
184
main.go
@ -1,184 +0,0 @@
|
||||
package termbin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.tcp.direct/kayos/common/squish"
|
||||
"github.com/yunginnanet/Rate5"
|
||||
"golang.org/x/tools/godoc/util"
|
||||
"inet.af/netaddr"
|
||||
)
|
||||
|
||||
const (
|
||||
MessageRatelimited = "RATELIMIT_REACHED"
|
||||
MessageSizeLimited = "MAX_SIZE_EXCEEDED"
|
||||
MessageBinaryData = "BINARY_DATA_REJECTED"
|
||||
)
|
||||
|
||||
type TermDumpster struct {
|
||||
gzip bool
|
||||
maxSize int64
|
||||
timeout time.Duration
|
||||
handler Handler
|
||||
log Logger
|
||||
*rate5.Limiter
|
||||
*sync.Pool
|
||||
}
|
||||
|
||||
type Logger interface {
|
||||
Printf(format string, v ...interface{})
|
||||
}
|
||||
|
||||
type Handler interface {
|
||||
Ingest(data []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
type dummyLogger struct{}
|
||||
|
||||
func (dummyLogger) Printf(format string, v ...interface{}) {
|
||||
_, _ = fmt.Printf(format, v...)
|
||||
}
|
||||
|
||||
func NewTermDumpster(handler Handler) *TermDumpster {
|
||||
td := &TermDumpster{
|
||||
maxSize: 3 << 20,
|
||||
timeout: 5 * time.Second,
|
||||
Limiter: rate5.NewStrictLimiter(60, 5),
|
||||
handler: handler,
|
||||
log: dummyLogger{},
|
||||
}
|
||||
td.Pool = &sync.Pool{
|
||||
New: func() any { return new(bytes.Buffer) },
|
||||
}
|
||||
return td
|
||||
}
|
||||
|
||||
func (td *TermDumpster) WithGzip() *TermDumpster {
|
||||
td.gzip = true
|
||||
return td
|
||||
}
|
||||
|
||||
func (td *TermDumpster) WithMaxSize(size int64) *TermDumpster {
|
||||
td.maxSize = size
|
||||
return td
|
||||
}
|
||||
|
||||
func (td *TermDumpster) WithTimeout(timeout time.Duration) *TermDumpster {
|
||||
td.timeout = timeout
|
||||
return td
|
||||
}
|
||||
|
||||
func (td *TermDumpster) WithLogger(logger Logger) *TermDumpster {
|
||||
td.log = logger
|
||||
return td
|
||||
}
|
||||
|
||||
type termbinClient struct {
|
||||
parent *TermDumpster
|
||||
addr string
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (c *termbinClient) UniqueKey() string {
|
||||
return c.addr
|
||||
}
|
||||
|
||||
func (td *TermDumpster) newClient(c net.Conn) *termbinClient {
|
||||
cipp, _ := netaddr.ParseIPPort(c.RemoteAddr().String())
|
||||
return &termbinClient{parent: td, addr: cipp.IP().String(), Conn: c}
|
||||
}
|
||||
|
||||
func (c *termbinClient) write(data []byte) {
|
||||
if _, err := c.Write(data); err != nil {
|
||||
c.parent.log.Printf("termbinClient: %s error: %w", c.RemoteAddr().String(), err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *termbinClient) writeString(data string) {
|
||||
c.write([]byte(data))
|
||||
}
|
||||
|
||||
func (td *TermDumpster) accept(c net.Conn) {
|
||||
var (
|
||||
final []byte
|
||||
length int64
|
||||
)
|
||||
client := td.newClient(c)
|
||||
if td.Check(client) {
|
||||
client.writeString(MessageRatelimited)
|
||||
client.Close()
|
||||
td.log.Printf("termbinClient: %s error: %s", client.RemoteAddr().String(), MessageRatelimited)
|
||||
return
|
||||
}
|
||||
buf := td.Pool.Get().(*bytes.Buffer)
|
||||
defer func() {
|
||||
_ = client.Close()
|
||||
buf.Reset()
|
||||
td.Put(buf)
|
||||
}()
|
||||
readLoop:
|
||||
for {
|
||||
if err := client.SetReadDeadline(time.Now().Add(td.timeout)); err != nil {
|
||||
td.log.Printf("failed to set read deadline: %s error: %w", client.RemoteAddr().String(), err)
|
||||
return
|
||||
}
|
||||
n, err := buf.ReadFrom(client)
|
||||
if err != nil {
|
||||
switch err.Error() {
|
||||
case "EOF":
|
||||
break readLoop
|
||||
case "read tcp " + client.LocalAddr().String() + "->" + client.RemoteAddr().String() + ": i/o timeout":
|
||||
break readLoop
|
||||
default:
|
||||
td.log.Printf("termbinClient: %s error: %w", client.RemoteAddr().String(), err)
|
||||
return
|
||||
}
|
||||
}
|
||||
length += n
|
||||
if length > td.maxSize {
|
||||
client.writeString(MessageSizeLimited)
|
||||
return
|
||||
}
|
||||
}
|
||||
if !util.IsText(buf.Bytes()) {
|
||||
client.writeString(MessageBinaryData)
|
||||
return
|
||||
}
|
||||
if td.gzip {
|
||||
if final = squish.Gzip(buf.Bytes()); final == nil {
|
||||
final = buf.Bytes()
|
||||
}
|
||||
}
|
||||
resp, err := td.handler.Ingest(final)
|
||||
if err != nil {
|
||||
if resp == nil {
|
||||
client.writeString("INTERNAL_ERROR")
|
||||
}
|
||||
td.log.Printf("termbinClient: %s error: %w", client.RemoteAddr().String(), err)
|
||||
}
|
||||
_, err = client.Write(resp)
|
||||
if err != nil {
|
||||
td.log.Printf("termbinClient: %s failed to deliver result: %w", client.RemoteAddr().String(), err)
|
||||
}
|
||||
}
|
||||
|
||||
// Listen starts the TCP server
|
||||
func (td *TermDumpster) Listen(addr string, port string) error {
|
||||
l, err := net.Listen("tcp", addr+":"+port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer l.Close()
|
||||
for {
|
||||
c, err := l.Accept()
|
||||
if err != nil {
|
||||
td.log.Printf("Error accepting connection: %s", err.Error())
|
||||
continue
|
||||
}
|
||||
go td.accept(c)
|
||||
}
|
||||
}
|
146
putxt.go
Normal file
146
putxt.go
Normal file
@ -0,0 +1,146 @@
|
||||
package putxt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"git.tcp.direct/kayos/common/squish"
|
||||
"golang.org/x/tools/godoc/util"
|
||||
"inet.af/netaddr"
|
||||
)
|
||||
|
||||
type Handler interface {
|
||||
Ingest(data []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
type termbinClient struct {
|
||||
parent *TermDumpster
|
||||
addr string
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (c *termbinClient) UniqueKey() string {
|
||||
return c.addr
|
||||
}
|
||||
|
||||
func (td *TermDumpster) newClient(c net.Conn) *termbinClient {
|
||||
cipp, _ := netaddr.ParseIPPort(c.RemoteAddr().String())
|
||||
return &termbinClient{parent: td, addr: cipp.IP().String(), Conn: c}
|
||||
}
|
||||
|
||||
func (c *termbinClient) write(data []byte) {
|
||||
if _, err := c.Write(data); err != nil {
|
||||
c.parent.log.Printf("termbinClient: %s error: %s", c.RemoteAddr().String(), err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (c *termbinClient) writeString(data string) {
|
||||
c.write([]byte(data))
|
||||
}
|
||||
|
||||
func (td *TermDumpster) accept(c net.Conn) {
|
||||
var final []byte
|
||||
client := td.newClient(c)
|
||||
if td.Check(client) {
|
||||
client.writeString(MessageRatelimited)
|
||||
_ = client.Close()
|
||||
td.log.Printf("termbinClient: %s error: %s", client.RemoteAddr().String(), MessageRatelimited)
|
||||
return
|
||||
}
|
||||
buf := td.Pool.Get()
|
||||
defer func() {
|
||||
_ = client.Close()
|
||||
td.Pool.MustPut(buf)
|
||||
}()
|
||||
if err := client.SetReadDeadline(time.Now().Add(td.timeout)); err != nil {
|
||||
td.log.Printf("failed to set read deadline: %s error: %s", client.RemoteAddr().String(), err.Error())
|
||||
return
|
||||
}
|
||||
readLoop:
|
||||
for {
|
||||
_, err := buf.ReadFrom(client)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
break readLoop
|
||||
case os.IsTimeout(err):
|
||||
break readLoop
|
||||
default:
|
||||
td.log.Printf("termbinClient: %s error: %s", client.RemoteAddr().String(), err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
if int64(buf.Len()) > td.maxSize {
|
||||
client.writeString(MessageSizeLimited)
|
||||
return
|
||||
}
|
||||
}
|
||||
if !util.IsText(buf.Bytes()) {
|
||||
client.writeString(MessageBinaryData)
|
||||
return
|
||||
}
|
||||
if td.gzip {
|
||||
if final = squish.Gzip(buf.Bytes()); final == nil {
|
||||
client.writeString(MessageInternalError)
|
||||
td.log.Printf(
|
||||
"termbinClient: %s error: gzipping data provided empty result",
|
||||
client.RemoteAddr().String())
|
||||
return
|
||||
}
|
||||
}
|
||||
resp, err := td.handler.Ingest(final)
|
||||
if err != nil {
|
||||
if resp == nil {
|
||||
client.writeString(MessageInternalError)
|
||||
}
|
||||
td.log.Printf("termbinClient: %s error: %s", client.RemoteAddr().String(), err.Error())
|
||||
return
|
||||
}
|
||||
_, err = client.Write(resp)
|
||||
if err != nil {
|
||||
td.log.Printf("termbinClient: %s failed to deliver result: %s", client.RemoteAddr().String(), err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (td *TermDumpster) handle(l net.Listener) {
|
||||
defer l.Close()
|
||||
for {
|
||||
c, acceptErr := l.Accept()
|
||||
if acceptErr != nil {
|
||||
td.log.Printf("Error accepting connection: %s", acceptErr.Error())
|
||||
continue
|
||||
}
|
||||
go td.accept(c)
|
||||
}
|
||||
}
|
||||
|
||||
// Listen starts the TCP server
|
||||
func (td *TermDumpster) Listen(addr string, port string) error {
|
||||
l, err := net.Listen("tcp", addr+":"+port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
td.handle(l)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListenUnixSocket starts the unix socket listener
|
||||
func (td *TermDumpster) ListenUnixSocket(path string) error {
|
||||
unixAddr, err := net.ResolveUnixAddr("unix", path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_ = syscall.Unlink(path)
|
||||
mask := syscall.Umask(0o077)
|
||||
unixListener, err := net.ListenUnix("unix", unixAddr)
|
||||
syscall.Umask(mask)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
td.handle(unixListener)
|
||||
return nil
|
||||
}
|
75
putxt_test.go
Normal file
75
putxt_test.go
Normal file
@ -0,0 +1,75 @@
|
||||
package putxt
|
||||
|
||||
// meant to act as a simple example
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.tcp.direct/kayos/common/entropy"
|
||||
"git.tcp.direct/kayos/common/squish"
|
||||
)
|
||||
|
||||
type handler struct {
|
||||
t *testing.T
|
||||
needle []byte
|
||||
}
|
||||
|
||||
func (h *handler) Ingest(data []byte) ([]byte, error) {
|
||||
var err error
|
||||
data, err = squish.Gunzip(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.ReplaceAll(string(data), "\n", "") == string(h.needle) {
|
||||
h.t.Log("got needle, echoing it back...")
|
||||
return h.needle, nil
|
||||
}
|
||||
return []byte("invalid request"), errors.New("data does not match generated test needle: " + string(data))
|
||||
}
|
||||
|
||||
func TestPutxt(t *testing.T) {
|
||||
socketPath := t.TempDir() + "/putxt.sock"
|
||||
testHandler := &handler{t: t, needle: []byte(entropy.RandStr(4096))}
|
||||
td := NewTermDumpster(testHandler).WithGzip().WithMaxSize(3 << 20).WithTimeout(5 * time.Second)
|
||||
var errChan = make(chan error)
|
||||
go func() {
|
||||
err := td.ListenUnixSocket(socketPath)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
select {
|
||||
case err := <-errChan:
|
||||
t.Fatalf("failed to listen on unix socket: %v", err.Error())
|
||||
default:
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
c, err := net.Dial("unix", socketPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to unix socket: %v", err.Error())
|
||||
}
|
||||
defer c.Close()
|
||||
res := make(chan []byte)
|
||||
go func() {
|
||||
buf := make([]byte, 4096)
|
||||
n, err := c.Read(buf)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
res <- buf[:n]
|
||||
}()
|
||||
_, err = c.Write(testHandler.needle)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write to unix socket: %v", err.Error())
|
||||
}
|
||||
|
||||
buf := <-res
|
||||
if !bytes.Equal(buf, testHandler.needle) {
|
||||
t.Fatalf("expected %s, got %s", testHandler.needle, string(buf))
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user