diff --git a/README.md b/README.md index fc61117..edb76d5 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,34 @@ -# go-pipe [![GoDoc](https://pkg.go.dev/badge/github.com/github/docs)](https://pkg.go.dev/github.com/github/go-pipe) +# go-pipe [![GoDoc](https://pkg.go.dev/badge/github.com/github/go-pipe/v2)](https://pkg.go.dev/github.com/github/go-pipe/v2) A package used to easily build command pipelines in your Go applications # Important We have not thoroughly tested this package on OSs other than Linux, especially Windows. At this time, using this package on Windows based systems is considered experimental and will be supported only on a best effort basis. +# Migrating to v2 + +It's normal for pipelines to stop before all input has been consumed[^1]. If an earlier stage continues writing after that happens, the write side of the pipe can fail with `EPIPE`, `SIGPIPE`, or `io.ErrClosedPipe`. + +In go-pipe v1 it was possible to get away without handling this case, because a command stage's stdin was connected in a way that often (but not necessarily!) drained the write side and hid the error from the previous stage feeding it. That was an implementation detail, not a guarantee. In go-pipe v2, producer stages are more likely to be connected directly to a command's stdin, and thus see the error themselves. + +Fortunately, this is easily handled by wrapping the stage with `pipe.IgnoreError(stage, pipe.IsPipeError)`. If the producer only writes output and is otherwise stateless, that's the only thing needed. + +If the producer also updates state, metrics, cursors, or has other side effects, in a way that depends on how much of the output was produced, then in addition to using `pipe.IgnoreError`, you must also ensure producer-owned state is brought to a consistent point before returning the error. + +For example, if a stateful producer function must process its entire input for correctness regardless of whether it was read by the consumer, it should use a pattern like: + +```go +var writeErr error +for _, item := range items { + updateState(item) + if writeErr == nil { + _, writeErr = fmt.Fprintln(stdout, item) + } +} +return writeErr +``` + # Links -* [Docs](https://pkg.go.dev/github.com/github/go-pipe) +* [Docs](https://pkg.go.dev/github.com/github/go-pipe/v2) + +[^1]: In `cat foo | head | grep -q`, for example, either `head` or `grep` could exit before its input is fully consumed. diff --git a/go.mod b/go.mod index 6f69110..041809e 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/github/go-pipe +module github.com/github/go-pipe/v2 go 1.24.0 diff --git a/internal/ptree/ptree.go b/internal/ptree/ptree.go deleted file mode 100644 index 8ce22da..0000000 --- a/internal/ptree/ptree.go +++ /dev/null @@ -1,303 +0,0 @@ -// Package ptree contains utilities for dealing with Linux process trees. -package ptree - -import ( - "bytes" - "errors" - "io" - "os" - "strconv" - "strings" - "sync" -) - -const ( - // initialReadBufSize is the starting capacity of the buffer used for - // reading /proc files. /proc//status is typically ~1 KiB and the - // per-task "children" files are smaller, so 4 KiB covers the common - // case in a single read. - initialReadBufSize = 4 * 1024 -) - -var errNoRss = errors.New("RssAnon was not found") - -type ProcessTree struct { - path string -} - -func NewProcessTree(path string) ProcessTree { - return ProcessTree{ - path: path, - } -} - -// readBufPool reuses the byte buffer that holds the contents of a /proc file -// across calls to getProcessRSSAnon / walkChildrenFile, so that the -// per-poll work doesn't allocate (and then garbage-collect) a fresh buffer -// for every process in the tree. -var readBufPool = sync.Pool{ - New: func() any { - b := make([]byte, 0, initialReadBufSize) - return &b - }, -} - -// readProcFile reads all of path into a buffer borrowed from readBufPool. -// -// On success, the returned slice is only valid until bufPtr is returned to -// the pool, which the caller MUST do (typically with -// `defer readBufPool.Put(bufPtr)`, placed after the error check). -// -// On error, bufPtr is nil and the buffer has already been released, so the -// caller must not Put it back. -// -// Compared to os.ReadFile, this skips the (useless for /proc) Stat call -// used to pre-size the buffer and reuses the buffer across calls. It also -// returns the underlying *[]byte rather than a closure so the release path -// allocates nothing. -func readProcFile(path string) (data []byte, bufPtr *[]byte, err error) { - f, err := os.Open(path) - if err != nil { - return nil, nil, err - } - defer f.Close() - - bufPtr = readBufPool.Get().(*[]byte) - buf := (*bufPtr)[:0] - for { - if len(buf) == cap(buf) { - // Grow via append; this only allocates if the pooled - // buffer was too small. Subsequent calls will reuse - // the grown buffer because we store it back below. - buf = append(buf, 0)[:len(buf)] - } - n, rerr := f.Read(buf[len(buf):cap(buf)]) - buf = buf[:len(buf)+n] - if rerr == io.EOF { - break - } - if rerr != nil { - *bufPtr = buf - readBufPool.Put(bufPtr) - return nil, nil, rerr - } - if n == 0 { - // Defensive: io.Reader allows (0, nil) returns; treat - // as EOF rather than spinning. Real *os.File on /proc - // shouldn't hit this, but mocks or future runtime - // behavior might. - break - } - } - *bufPtr = buf - return buf, bufPtr, nil -} - -// Return the RSSAnon of a single process `pid`. -func (pt ProcessTree) GetProcessRSSAnon(pid int) (uint64, error) { - status := pt.path + "/" + strconv.Itoa(pid) + "/status" - data, bufPtr, err := readProcFile(status) - if os.IsNotExist(err) { - // process is already gone - return 0, nil - } - if err != nil { - return 0, err - } - defer readBufPool.Put(bufPtr) - - prefix := []byte("RssAnon:") - rest := data - for len(rest) > 0 { - var line []byte - if nl := bytes.IndexByte(rest, '\n'); nl >= 0 { - line, rest = rest[:nl], rest[nl+1:] - } else { - line, rest = rest, nil - } - // Fast prefix check before paying for the string conversion. - if !bytes.HasPrefix(line, prefix) { - continue - } - if rss, ok := ParseRSSAnon(string(line)); ok { - return rss, nil - } - } - return 0, errNoRss -} - -// Return the total RSS of the tree of processes rooted at `pid`. -// -// If the passed root pid is that of a kernel thread, as a special case, we -// return zero and no error. -// -// Errors encountered while walking the children are ignored, since it can -// change while traversing it. -func (pt ProcessTree) GetProcessTreeRSSAnon(pid int) (uint64, error) { - total, err := pt.GetProcessRSSAnon(pid) - if err != nil { - if err == errNoRss { - // these are typically kernel threads, which don't have an address space to measure - return 0, nil - } - return 0, err - } - - pt.WalkChildren(pid, func(pid int) { - mem, err := pt.GetProcessRSSAnon(pid) - if err != nil { - return - } - total += mem - }) - - return total, nil -} - -func (pt ProcessTree) WalkChildren(pid int, walkFn func(int)) { - pt.walkChildPids(pid, walkFn, map[int]bool{pid: true}) -} - -func (pt ProcessTree) walkChildPids(pid int, walkFn func(int), visited map[int]bool) { - // List the per-thread directories under /proc//task and read each - // task's "children" file directly. This avoids filepath.Glob, which - // would Stat every match on top of the readdir we already need. - taskDir := pt.path + "/" + strconv.Itoa(pid) + "/task" - entries, err := os.ReadDir(taskDir) - if err != nil { - return - } - - for _, entry := range entries { - // task/ should only contain numeric TID directories. Skip - // anything else defensively; this mirrors the implicit - // filtering that filepath.Glob("*/children") provided. - // A byte-range check avoids the error allocation that - // strconv.Atoi would incur for non-numeric names. - if !isAllDigits(entry.Name()) { - continue - } - pt.walkChildrenFile(taskDir+"/"+entry.Name()+"/children", walkFn, visited) - } -} - -func (pt ProcessTree) walkChildrenFile(filename string, walkFn func(int), visited map[int]bool) { - data, bufPtr, err := readProcFile(filename) - if err != nil { - return - } - defer readBufPool.Put(bufPtr) - - // children is a whitespace-separated list of decimal PIDs. Parse it in - // place to avoid the string(data) conversion and the []string allocated - // by strings.Fields. - i := 0 - for i < len(data) { - for i < len(data) && isASCIISpace(data[i]) { - i++ - } - if i >= len(data) { - return - } - pid := 0 - start := i - for i < len(data) && data[i] >= '0' && data[i] <= '9' { - pid = pid*10 + int(data[i]-'0') - i++ - } - if i == start { - // Not a digit; skip until next whitespace to stay in sync. - for i < len(data) && !isASCIISpace(data[i]) { - i++ - } - continue - } - if i-start > 10 { - // Realistic Linux PIDs fit in well under 10 digits - // (PID_MAX is 2^22). A longer digit run can't be a - // real PID and would risk silently overflowing the - // int accumulator, so skip it. - continue - } - if visited[pid] { - continue - } - walkFn(pid) - visited[pid] = true - pt.walkChildPids(pid, walkFn, visited) - } -} - -// parseRSSAnon parses an "RssAnon" line from /proc/*/status and returns the size. -// The entire line should be passed in, with or without the line ending. If the -// line looks like "RssAnon: 1234 kB", the byte size will be returned. If the -// line isn't parseable, (0, false) will be returned. -func ParseRSSAnon(s string) (uint64, bool) { - const prefix = "RssAnon:" - if !strings.HasPrefix(s, prefix) { - return 0, false - } - s = s[len(prefix):] - - // Optional whitespace before the number. - i := 0 - for i < len(s) && isASCIISpace(s[i]) { - i++ - } - - // One or more digits. - digitsStart := i - for i < len(s) && s[i] >= '0' && s[i] <= '9' { - i++ - } - if i == digitsStart { - return 0, false - } - kb, err := strconv.ParseUint(s[digitsStart:i], 10, 64) - if err != nil { - return 0, false - } - - // At least one whitespace between the number and "kB". - if i >= len(s) || !isASCIISpace(s[i]) { - return 0, false - } - for i < len(s) && isASCIISpace(s[i]) { - i++ - } - - // Literal "kB", then either end-of-string or whitespace. - const unit = "kB" - if !strings.HasPrefix(s[i:], unit) { - return 0, false - } - i += len(unit) - if i < len(s) && !isASCIISpace(s[i]) { - return 0, false - } - return kb * 1024, true -} - -// isASCIISpace matches the character class that Go's regexp engine uses for -// \s in non-Unicode mode: [\t\n\f\r ]. -func isASCIISpace(b byte) bool { - switch b { - case ' ', '\t', '\n', '\f', '\r': - return true - } - return false -} - -// isAllDigits reports whether s is non-empty and consists entirely of ASCII -// decimal digits. Used as a cheap allocation-free numeric-name filter. -func isAllDigits(s string) bool { - if len(s) == 0 { - return false - } - for i := 0; i < len(s); i++ { - if s[i] < '0' || s[i] > '9' { - return false - } - } - return true -} diff --git a/internal/ptree/ptree_linux.go b/internal/ptree/ptree_linux.go deleted file mode 100644 index 7019693..0000000 --- a/internal/ptree/ptree_linux.go +++ /dev/null @@ -1,21 +0,0 @@ -package ptree - -var DefaultProcessTree = ProcessTree{ - path: "/proc", -} - -// Walk the child processes of the specified root process. walkFn will be called -// for each child found. It will not be called for the root process. Any errors -// will be ignored, since they may be just a consequence of the process tree -// changing during traversal. -func WalkChildren(pid int, walkFn func(int)) { - DefaultProcessTree.WalkChildren(pid, walkFn) -} - -func GetProcessRSSAnon(pid int) (uint64, error) { - return DefaultProcessTree.GetProcessRSSAnon(pid) -} - -func GetProcessTreeRSSAnon(pid int) (uint64, error) { - return DefaultProcessTree.GetProcessTreeRSSAnon(pid) -} diff --git a/internal/ptree/ptree_test.go b/internal/ptree/ptree_test.go deleted file mode 100644 index 9c6d4e4..0000000 --- a/internal/ptree/ptree_test.go +++ /dev/null @@ -1,260 +0,0 @@ -package ptree_test - -import ( - "bytes" - "fmt" - "os" - "path/filepath" - "sort" - "strconv" - "testing" - - "github.com/github/go-pipe/internal/ptree" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// writeStatus creates only what GetProcessRSSAnon reads: //status. -// If rssKB is zero, the RssAnon line is omitted (mimicking kernel threads). -func writeStatus(t testing.TB, root string, pid int, rssKB uint64) { - t.Helper() - pidDir := filepath.Join(root, strconv.Itoa(pid)) - require.NoError(t, os.MkdirAll(pidDir, 0o755)) - var status string - if rssKB > 0 { - status = fmt.Sprintf("Name:\tfake\nRssAnon:\t%d kB\nVmSize:\t1000 kB\n", rssKB) - } else { - status = "Name:\tfake\nVmSize:\t1000 kB\n" - } - require.NoError(t, os.WriteFile(filepath.Join(pidDir, "status"), []byte(status), 0o600)) -} - -// writeChildren creates //task//children containing the -// space-separated child pids. Only call this for processes that actually -// have children; getProcessTreeRSSAnon copes fine with the task/ directory -// being absent for leaves. -func writeChildren(t testing.TB, root string, pid int, children []int) { - t.Helper() - writeThreadChildren(t, root, pid, pid, children) -} - -// writeThreadChildren creates //task//children for a specific -// thread id, so tests can exercise multi-threaded processes. -func writeThreadChildren(t testing.TB, root string, pid, tid int, children []int) { - t.Helper() - taskDir := filepath.Join(root, strconv.Itoa(pid), "task", strconv.Itoa(tid)) - require.NoError(t, os.MkdirAll(taskDir, 0o755)) - var buf bytes.Buffer - for _, c := range children { - fmt.Fprintf(&buf, "%d ", c) - } - require.NoError(t, os.WriteFile(filepath.Join(taskDir, "children"), buf.Bytes(), 0o600)) -} - -func TestGetProcessRSSAnon(t *testing.T) { - const kb = 1024 - root := t.TempDir() - writeStatus(t, root, 100, 15032) - writeStatus(t, root, 101, 0) // kernel-thread-like: no RssAnon line - - pt := ptree.NewProcessTree(root) - - t.Run("reads RssAnon", func(t *testing.T) { - rss, err := pt.GetProcessRSSAnon(100) - require.NoError(t, err) - assert.Equal(t, uint64(15032*kb), rss) - }) - - t.Run("missing RssAnon line returns an error", func(t *testing.T) { - _, err := pt.GetProcessRSSAnon(101) - assert.Error(t, err) - }) - - t.Run("missing pid returns (0, nil)", func(t *testing.T) { - // A process that has already exited disappears from /proc; the function - // treats that as a non-error zero. - rss, err := pt.GetProcessRSSAnon(999) - require.NoError(t, err) - assert.Equal(t, uint64(0), rss) - }) -} - -func TestGetProcessTreeRSSAnon(t *testing.T) { - const kb = 1024 - - t.Run("leaf process returns its own RssAnon", func(t *testing.T) { - root := t.TempDir() - writeStatus(t, root, 100, 1000) - - pt := ptree.NewProcessTree(root) - - total, err := pt.GetProcessTreeRSSAnon(100) - require.NoError(t, err) - assert.Equal(t, uint64(1000*kb), total) - }) - - t.Run("sums root and descendants", func(t *testing.T) { - // 100 -> {101 -> 103, 102} - root := t.TempDir() - writeStatus(t, root, 100, 1000) - writeStatus(t, root, 101, 200) - writeStatus(t, root, 102, 50) - writeStatus(t, root, 103, 7) - writeChildren(t, root, 100, []int{101, 102}) - writeChildren(t, root, 101, []int{103}) - - pt := ptree.NewProcessTree(root) - - total, err := pt.GetProcessTreeRSSAnon(100) - require.NoError(t, err) - assert.Equal(t, uint64((1000+200+50+7)*kb), total) - }) - - t.Run("kernel-thread root returns (0, nil)", func(t *testing.T) { - // Root has no RssAnon line; the function maps errNoRss to (0, nil). - root := t.TempDir() - writeStatus(t, root, 100, 0) - - pt := ptree.NewProcessTree(root) - - total, err := pt.GetProcessTreeRSSAnon(100) - require.NoError(t, err) - assert.Equal(t, uint64(0), total) - }) -} - -func TestWalkChildren(t *testing.T) { - t.Run("walks all descendants", func(t *testing.T) { - // 100 -> {101, 102 -> 103}. Verifies the callback fires for - // every descendant (not just direct children) and is not - // invoked for the root. - root := t.TempDir() - writeChildren(t, root, 100, []int{101, 102}) - writeChildren(t, root, 102, []int{103}) - - pt := ptree.NewProcessTree(root) - - var seen []int - pt.WalkChildren(100, func(pid int) { seen = append(seen, pid) }) - sort.Ints(seen) - assert.Equal(t, []int{101, 102, 103}, seen) - }) - - t.Run("iterates every thread under task/ and dedups", func(t *testing.T) { - // 100 has two threads (100 and 200); each thread reports a - // different set of children, with 102 listed by both threads - // to exercise the visited dedup. - root := t.TempDir() - writeThreadChildren(t, root, 100, 100, []int{101, 102}) - writeThreadChildren(t, root, 100, 200, []int{102, 103}) - - pt := ptree.NewProcessTree(root) - - var seen []int - pt.WalkChildren(100, func(pid int) { seen = append(seen, pid) }) - sort.Ints(seen) - assert.Equal(t, []int{101, 102, 103}, seen) - }) -} - -// BenchmarkGetProcessTreeRSSAnon measures the cost of a single poll over a -// small process tree (a root plus a few direct children). -func BenchmarkGetProcessTreeRSSAnon(b *testing.B) { - const rootPid = 100 - root := b.TempDir() - writeStatus(b, root, rootPid, 1000) - children := []int{101, 102, 103} - writeChildren(b, root, rootPid, children) - for _, c := range children { - writeStatus(b, root, c, 200) - } - - pt := ptree.NewProcessTree(root) - - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, err := pt.GetProcessTreeRSSAnon(rootPid) - if err != nil { - b.Fatal(err) - } - } -} - -func TestParseRss(t *testing.T) { - const kb = 1024 - - okExamples := []struct { - input string - result uint64 - }{ - { - input: "RssAnon:\t 15032 kB", - result: 15032 * kb, - }, - { - input: "RssAnon:\t99915032 kB", - result: 99915032 * kb, - }, - { - input: "RssAnon:\t 1 kB", - result: kb, - }, - // Exactly what the kernel emits via SEQ_PUT_DEC: "RssAnon:\t" + - // 8-wide right-justified decimal + " kB\n". See fs/proc/task_mmu.c - // (task_mem). The trailing newline must be tolerated. - { - input: "RssAnon:\t 15032 kB\n", - result: 15032 * kb, - }, - // A value wider than the 8-char padding (no leading spaces). - { - input: "RssAnon:\t12345678 kB\n", - result: 12345678 * kb, - }, - { - input: "RssAnon:\t 0 kB\n", - result: 0, - }, - } - - for _, example := range okExamples { - rss, ok := ptree.ParseRSSAnon(example.input) - if assert.Truef(t, ok, "should be able to parse %q", example.input) { - assert.Equalf(t, example.result, rss, "value of %q", example.input) - } - } - - badExamples := []string{ - "", - "\n", - "RssAnon:\t 123", - "RssAnonn:\t 123 kB", - "RssAno:\t 123 kB", - "Blah:\t 123 kB", - "Blah:", - "123", - } - - for _, example := range badExamples { - _, ok := ptree.ParseRSSAnon(example) - assert.Falsef(t, ok, "should not be able to parse %q", example) - } -} - -func BenchmarkParseRss(b *testing.B) { - b.Run("match", func(b *testing.B) { - for i := 0; i < b.N; i++ { - rss, ok := ptree.ParseRSSAnon("RssAnon:\t 15032 kB") - require.True(b, ok) - require.EqualValues(b, 15032*1024, rss) - } - }) - - b.Run("no match", func(b *testing.B) { - for i := 0; i < b.N; i++ { - _, ok := ptree.ParseRSSAnon("Other:\t 15032 kB") - require.False(b, ok) - } - }) -} diff --git a/pipe/close_responsibility_test.go b/pipe/close_responsibility_test.go new file mode 100644 index 0000000..20a5a28 --- /dev/null +++ b/pipe/close_responsibility_test.go @@ -0,0 +1,188 @@ +package pipe + +import ( + "context" + "io" + "os/exec" + "strings" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// readCloseSpy records whether Close was called. +type readCloseSpy struct { + io.Reader + closeCount atomic.Uint32 +} + +func (r *readCloseSpy) Close() error { + r.closeCount.Add(1) + return nil +} + +// writeCloseSpy records whether Close was called. +type writeCloseSpy struct { + io.Writer + closeCount atomic.Uint32 +} + +func (w *writeCloseSpy) Close() error { + w.closeCount.Add(1) + return nil +} + +// TestGoStageHonorsStreamOwnership verifies that a Function stage closes +// stdin/stdout iff the corresponding stream is closing. +func TestGoStageHonorsStreamOwnership(t *testing.T) { + cases := []struct { + name string + leaveIn, leaveOut bool + }{ + {"own both", false, false}, + {"leave stdin open", true, false}, + {"leave stdout open", false, true}, + {"leave both open", true, true}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + in := &readCloseSpy{Reader: strings.NewReader("hi")} + out := &writeCloseSpy{Writer: io.Discard} + + s := Function("f", func(_ context.Context, _ Env, stdin io.Reader, stdout io.Writer) error { + _, err := io.Copy(stdout, stdin) + return err + }) + + require.NoError(t, s.Start( + context.Background(), StageOptions{}, + inputForTest(in, !tc.leaveIn), + outputForTest(out, !tc.leaveOut), + )) + require.NoError(t, s.Wait()) + + if tc.leaveIn { + assert.EqualValues(t, 0, in.closeCount.Load(), "closing stdin=%v", !tc.leaveIn) + } else { + assert.EqualValues(t, 1, in.closeCount.Load(), "closing stdin=%v", !tc.leaveIn) + } + if tc.leaveOut { + assert.EqualValues(t, 0, out.closeCount.Load(), "closing stdout=%v", !tc.leaveOut) + } else { + assert.EqualValues(t, 1, out.closeCount.Load(), "closing stdout=%v", !tc.leaveOut) + } + }) + } +} + +func TestStreamConstructorsPreserveOwnershipAndDynamicType(t *testing.T) { + borrowedReader := &readCloseSpy{Reader: strings.NewReader("borrowed")} + borrowedInput := Input(borrowedReader) + assert.Same(t, borrowedReader, borrowedInput.Reader()) + assert.NoError(t, borrowedInput.Close()) + assert.EqualValues(t, 0, borrowedReader.closeCount.Load()) + assert.NoError(t, borrowedInput.Close()) + assert.EqualValues(t, 0, borrowedReader.closeCount.Load()) + + ownedReader := &readCloseSpy{Reader: strings.NewReader("owned")} + ownedInput := ClosingInput(ownedReader) + assert.Same(t, ownedReader, ownedInput.Reader()) + assert.NoError(t, ownedInput.Close()) + assert.EqualValues(t, 1, ownedReader.closeCount.Load()) + assert.NoError(t, ownedInput.Close()) + assert.EqualValues(t, 1, ownedReader.closeCount.Load()) + + borrowedWriter := &writeCloseSpy{Writer: &strings.Builder{}} + borrowedOutput := Output(borrowedWriter) + assert.Same(t, borrowedWriter, borrowedOutput.Writer()) + assert.NoError(t, borrowedOutput.Close()) + assert.EqualValues(t, 0, borrowedWriter.closeCount.Load()) + assert.NoError(t, borrowedOutput.Close()) + assert.EqualValues(t, 0, borrowedWriter.closeCount.Load()) + + ownedWriter := &writeCloseSpy{Writer: &writeCloseSpy{Writer: io.Discard}} + ownedOutput := ClosingOutput(ownedWriter) + assert.Same(t, ownedWriter, ownedOutput.Writer()) + assert.NoError(t, ownedOutput.Close()) + assert.EqualValues(t, 1, ownedWriter.closeCount.Load()) + assert.NoError(t, ownedOutput.Close()) + assert.EqualValues(t, 1, ownedWriter.closeCount.Load()) +} + +// TestCommandStageHonorsCloseStdin verifies that a command stage closes a +// non-file stdin (a "late" closer) iff the input stream is closing. An +// empty reader is used so exec.Cmd's input-copy goroutine sees EOF promptly. +func TestCommandStageHonorsCloseStdin(t *testing.T) { + for _, leave := range []bool{false, true} { + name := "owns stdin" + if leave { + name = "leaves stdin open" + } + t.Run(name, func(t *testing.T) { + in := &readCloseSpy{Reader: strings.NewReader("")} + + cmd := exec.Command("true") + s := CommandStage("true", cmd).(*commandStage) + + require.NoError(t, s.Start( + context.Background(), StageOptions{}, + inputForTest(in, !leave), + Output(nil), + )) + require.NoError(t, s.Wait()) + + if leave { + assert.EqualValues(t, 0, in.closeCount.Load(), "closing stdin=%v", !leave) + } else { + assert.EqualValues(t, 1, in.closeCount.Load(), "closing stdin=%v", !leave) + } + }) + } +} + +// TestCommandStageHonorsCloseStdout verifies the stdout counterpart: a +// non-file stdout (routed through the pooled-copy path) is closed iff +// the output stream is closing. +func TestCommandStageHonorsCloseStdout(t *testing.T) { + for _, leave := range []bool{false, true} { + name := "owns stdout" + if leave { + name = "leaves stdout open" + } + t.Run(name, func(t *testing.T) { + out := &writeCloseSpy{Writer: io.Discard} + + cmd := exec.Command("true") + s := CommandStage("true", cmd).(*commandStage) + + require.NoError(t, s.Start( + context.Background(), StageOptions{}, + Input(nil), + outputForTest(out, !leave), + )) + require.NoError(t, s.Wait()) + + if leave { + assert.EqualValues(t, 0, out.closeCount.Load(), "closing stdout=%v", !leave) + } else { + assert.EqualValues(t, 1, out.closeCount.Load(), "closing stdout=%v", !leave) + } + }) + } +} + +func inputForTest(r io.ReadCloser, closing bool) *InputStream { + if closing { + return ClosingInput(r) + } + return Input(r) +} + +func outputForTest(w io.WriteCloser, closing bool) *OutputStream { + if closing { + return ClosingOutput(w) + } + return Output(w) +} diff --git a/pipe/command.go b/pipe/command.go index 2113902..27d34c6 100644 --- a/pipe/command.go +++ b/pipe/command.go @@ -15,14 +15,16 @@ import ( "golang.org/x/sync/errgroup" ) -var errProcessInfoMissing = errors.New("cmd.Process is nil") - // commandStage is a pipeline `Stage` based on running an external // command and piping the data through its stdin and stdout. type commandStage struct { - name string - stdin io.Closer - cmd *exec.Cmd + name string + cmd *exec.Cmd + + // lateClosers is a list of things that have to be closed once the + // command has finished. + lateClosers []io.Closer + done chan struct{} wg errgroup.Group stderr bytes.Buffer @@ -32,6 +34,18 @@ type commandStage struct { ctxErr atomic.Value } +var ( + _ Stage = (*commandStage)(nil) + _ processProvider = (*commandStage)(nil) +) + +// processProvider is the hook external memory-watchers use to find the running +// process so they can sample its RSS and kill it if necessary. +type processProvider interface { + Process() *os.Process + Kill(error) +} + // Command returns a pipeline `Stage` based on the specified external // `command`, run with the given command-line `args`. Its stdin and // stdout are handled as usual, and its stderr is collected and @@ -61,33 +75,107 @@ func (s *commandStage) Name() string { return s.name } +func (s *commandStage) Process() *os.Process { + return s.cmd.Process +} + +func (s *commandStage) Requirements() StageRequirements { + return StageRequirements{ + Stdin: StreamPreferFile, + Stdout: StreamPreferFile, + } +} + func (s *commandStage) Start( - ctx context.Context, env Env, stdin io.ReadCloser, -) (io.ReadCloser, error) { + ctx context.Context, opts StageOptions, + stdin *InputStream, stdout *OutputStream, +) error { + r := stdin.Reader() + w := stdout.Writer() + if s.cmd.Dir == "" { - s.cmd.Dir = env.Dir + s.cmd.Dir = opts.Dir } - s.setupEnv(ctx, env) + s.setupEnv(ctx, opts.Env) + + // It is important that the streams that are used by a command be + // closed at the right time. When that is depends on the type of + // the stream. + // + // A subprocess ultimately needs its own copies of `*os.File` file + // descriptors for its stdin and stdout. The external command will + // "always" close those when it exits. + // + // (It's theoretically possible for a command to pass the open + // file descriptor to another, longer-lived process, in which case + // the file descriptor wouldn't necessarily get closed even when + // the command finishes. But that's ill-behaved in a command that + // is being used in a pipeline, so we'll ignore that possibility.) + // + // If a stream provided for use as stdin/stdout is an `*os.File`, + // then we set the corresponding field of `exec.Cmd` to that + // argument. This causes `exec.Cmd` to duplicate that file + // descriptor and passes the dup to the subprocess. Therefore, we + // want to close our own copy "early", namely as soon as the + // external command has started, because the external command will + // keep its own copy open as long as necessary (and no longer!). + // + // If a stdin/stdout stream is _not_ an `*os.File`, then + // `exec.Cmd` will take care of creating an `os.Pipe()`, copying + // from the provided stream into/out of the pipe, and eventually + // close both ends of the pipe. In that case, we must close the + // provided stream "late", namely only after the external command + // and the copy have finished. + + // Things that have to be closed as soon as the command has started: + var earlyClosers []io.Closer + + // See the type comment for `Stage` for the explanation of this closing behavior. + if r != nil { + s.cmd.Stdin = r + } - if stdin != nil { - // See the long comment in `Pipeline.Start()` for the - // explanation of this special case. - switch stdin := stdin.(type) { - case nopCloser: - s.cmd.Stdin = stdin.Reader - case nopCloserWriterTo: - s.cmd.Stdin = stdin.Reader - default: - s.cmd.Stdin = stdin + if _, ok := r.(*os.File); ok { + // We can close our copy as soon as the command has started + earlyClosers = append(earlyClosers, stdin) + } else { + // We need to close `stdin`, but only after the command has finished + s.lateClosers = append(s.lateClosers, stdin) + } + + closeEarlyClosers := func() { + for _, closer := range earlyClosers { + _ = closer.Close() } - // Also keep a copy so that we can close it when the command exits: - s.stdin = stdin } - stdout, err := s.cmd.StdoutPipe() - if err != nil { - return nil, err + // On error, Close any pipes we created and wait for the goroutines to + // exit before propagating the error. + cleanupOnStartFailure := func() { + closeEarlyClosers() + _ = s.wg.Wait() + _ = s.closeLateClosers() + } + + if w != nil { + if f, ok := w.(*os.File); ok { + s.cmd.Stdout = f + earlyClosers = append(earlyClosers, stdout) + } else { + s.lateClosers = append(s.lateClosers, stdout) + // Route the copy through our own pipe so we can use a + // pooled buffer rather than letting exec.Cmd allocate a + // fresh 32KB buffer for its internal io.Copy. + ec, err := s.setupPooledStdout(w) + if err != nil { + cleanupOnStartFailure() + return err + } + earlyClosers = append(earlyClosers, ec) + } + } else { + s.lateClosers = append(s.lateClosers, stdout) } // If the caller hasn't arranged otherwise, read the command's @@ -99,7 +187,8 @@ func (s *commandStage) Start( // can be sure. p, err := s.cmd.StderrPipe() if err != nil { - return nil, err + cleanupOnStartFailure() + return err } s.wg.Go(func() error { _, err := io.Copy(&s.stderr, p) @@ -116,9 +205,12 @@ func (s *commandStage) Start( s.runInOwnProcessGroup() if err := s.cmd.Start(); err != nil { - return nil, err + cleanupOnStartFailure() + return err } + closeEarlyClosers() + // Arrange for the process to be killed (gently) if the context // expires before the command exits normally: go func() { @@ -130,7 +222,7 @@ func (s *commandStage) Start( } }() - return stdout, nil + return nil } // setupEnv sets or modifies the environment that will be passed to @@ -219,21 +311,55 @@ func (s *commandStage) Wait() error { // Make sure that any stderr is copied before `s.cmd.Wait()` // closes the read end of the pipe: - wErr := s.wg.Wait() + wgErr := s.wg.Wait() err := s.cmd.Wait() err = s.filterCmdError(err) - if err == nil && wErr != nil { - err = wErr + if err == nil && wgErr != nil { + err = wgErr } - if s.stdin != nil { - cErr := s.stdin.Close() - if cErr != nil && err == nil { - return cErr - } + if closeErr := s.closeLateClosers(); err == nil { + err = closeErr } return err } + +func (s *commandStage) closeLateClosers() error { + var err error + for _, closer := range s.lateClosers { + if closeErr := closer.Close(); closeErr != nil && err == nil { + err = closeErr + } + } + s.lateClosers = nil + return err +} + +// setupPooledStdout creates an `os.Pipe()`, sets it as `cmd.Stdout`, +// and starts a goroutine that copies from the read end to `dst` using +// a pooled buffer (or `dst.ReadFrom` when `dst` implements it). The +// returned closer is the write end of the pipe; the caller must add +// it to `earlyClosers` so it is closed once the command has started. +// +// The buffer-pool optimization works for command stages whose stdout is +// not an `*os.File`. Without it, `exec.Cmd` would set up its own pipe +// and run `io.Copy` with a freshly allocated 32KB buffer per invocation. +func (s *commandStage) setupPooledStdout(dst io.Writer) (io.Closer, error) { + pr, pw, err := os.Pipe() + if err != nil { + return nil, err + } + s.cmd.Stdout = pw + s.wg.Go(func() error { + defer pr.Close() + _, err := pooledCopy(dst, pr) + if err != nil && !errors.Is(err, os.ErrClosed) { + return err + } + return nil + }) + return pw, nil +} diff --git a/pipe/command_linux.go b/pipe/command_linux.go deleted file mode 100644 index 997d2cd..0000000 --- a/pipe/command_linux.go +++ /dev/null @@ -1,20 +0,0 @@ -//go:build linux - -package pipe - -import ( - "context" - - "github.com/github/go-pipe/internal/ptree" -) - -// On linux, we can limit or observe memory usage in command stages. -var _ LimitableStage = (*commandStage)(nil) - -func (s *commandStage) GetRSSAnon(_ context.Context) (uint64, error) { - if s.cmd.Process == nil { - return 0, errProcessInfoMissing - } - - return ptree.GetProcessTreeRSSAnon(s.cmd.Process.Pid) -} diff --git a/pipe/command_nil_panic_test.go b/pipe/command_nil_panic_test.go index 740af73..1d03a12 100644 --- a/pipe/command_nil_panic_test.go +++ b/pipe/command_nil_panic_test.go @@ -7,6 +7,8 @@ import ( "context" "os/exec" "testing" + + "github.com/stretchr/testify/assert" ) func TestKillWithNilProcess(t *testing.T) { @@ -17,31 +19,7 @@ func TestKillWithNilProcess(t *testing.T) { done: make(chan struct{}), } - defer func() { - if r := recover(); r != nil { - t.Fatalf("PANIC OCCURRED (bug not fixed): %v", r) - } - }() - - stage.Kill(context.Canceled) - - t.Log("SUCCESS: Kill() handled nil Process gracefully without panicking") -} - -func TestKillWithFailedStart(t *testing.T) { - ctx := context.Background() - - stage := Command("/this/path/does/not/exist/invalid_command_12345") - - _, err := stage.Start(ctx, Env{}, nil) - if err == nil { - t.Fatal("Expected start to fail, but it succeeded") - } - - // At this point, if someone calls Kill (perhaps from a memory monitor - // or other external component), it could panic if Process is nil - - // Note: In the current implementation, Kill won't be called from the - // context goroutine because Start failed. But external callers like - // MemoryLimit could potentially call Kill on a failed stage. + assert.NotPanics(t, func() { + stage.Kill(context.Canceled) + }) } diff --git a/pipe/command_starterror_test.go b/pipe/command_starterror_test.go new file mode 100644 index 0000000..5a25cbf --- /dev/null +++ b/pipe/command_starterror_test.go @@ -0,0 +1,59 @@ +package pipe_test + +import ( + "bytes" + "context" + "os/exec" + "sync/atomic" + "testing" + + "github.com/github/go-pipe/v2/pipe" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestCommandStageStartFailureNoRace verifies that when `cmd.Start()` +// fails (e.g. command not found), the goroutine that +// `setupPooledStdout` spawned does not leak past `Pipeline.Run()`. +// `bytes.Buffer.ReadFrom` writes to the buffer's slice header via +// `grow()` before its first `Read()`, so a leaked goroutine races +// with the caller's access to the destination buffer once Run +// returns the error. Run a tight loop so `-race` is likely to catch +// any regression. +func TestCommandStageStartFailureNoRace(t *testing.T) { + for i := 0; i < 50; i++ { + var buf bytes.Buffer + p := pipe.New(pipe.WithStdout(&buf)) + p.Add(pipe.CommandStage("nope", exec.Command("this-binary-does-not-exist-xyz123"))) + require.Error(t, p.Run(context.Background())) + _ = buf.String() + } +} + +// trackingWriteCloser is a non-`*os.File` `io.WriteCloser` that records +// whether it has been closed. Because it isn't an `*os.File`, a command +// stage routes it through `setupPooledStdout` and closes it as a "late +// closer" (i.e. only after the command finishes / cleanup runs). +type trackingWriteCloser struct { + closed atomic.Bool +} + +func (w *trackingWriteCloser) Write(p []byte) (int, error) { return len(p), nil } + +func (w *trackingWriteCloser) Close() error { + w.closed.Store(true) + return nil +} + +// TestCommandStageStartFailureClosesLateClosers verifies that a +// `WithStdoutCloser` on the last stage is closed even when `cmd.Start()` +// fails. The closer is registered as a "late closer," which is normally +// drained by `Wait()`; since `Wait()` never runs when `Start()` fails, +// the start-failure cleanup path must close it instead. +func TestCommandStageStartFailureClosesLateClosers(t *testing.T) { + w := &trackingWriteCloser{} + p := pipe.New(pipe.WithStdoutCloser(w)) + p.Add(pipe.CommandStage("nope", exec.Command("this-binary-does-not-exist-xyz123"))) + require.Error(t, p.Run(context.Background())) + assert.True(t, w.closed.Load(), "expected late closer to be closed after Start() failure") +} diff --git a/pipe/command_stdout_fastpath_test.go b/pipe/command_stdout_fastpath_test.go new file mode 100644 index 0000000..0de3f51 --- /dev/null +++ b/pipe/command_stdout_fastpath_test.go @@ -0,0 +1,116 @@ +package pipe + +import ( + "context" + "os" + "os/exec" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestCommandStageStdoutFastPath asserts that when a commandStage's stdout is +// an `*os.File`, the file is set as `cmd.Stdout` so that `exec.Cmd` dup's the +// fd into the child process directly. This is one of the optimizations enabled +// by the Stage interface redesign in #21: the subprocess writes straight to +// the caller's destination fd with no Go-side copy stage in between, and the +// subprocess can detect when that fd is closed. +func TestCommandStageStdoutFastPath(t *testing.T) { + cases := []struct { + name string + closingStdout bool + }{ + { + name: "raw *os.File with closing stdout", + closingStdout: true, + }, + { + name: "raw *os.File with non-closing stdout", + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + f, err := os.CreateTemp(t.TempDir(), "stdout") + require.NoError(t, err) + t.Cleanup(func() { _ = f.Close() }) + + cmd := exec.Command("true") + s := CommandStage("true", cmd).(*commandStage) + + var stdout *OutputStream + if tc.closingStdout { + stdout = ClosingOutput(f) + } else { + stdout = Output(f) + } + + require.NoError(t, s.Start(ctx, StageOptions{}, Input(nil), stdout)) + t.Cleanup(func() { _ = s.Wait() }) + + gotFile, ok := s.cmd.Stdout.(*os.File) + require.Truef(t, ok, "expected cmd.Stdout to be *os.File, got %T", s.cmd.Stdout) + assert.Samef( + t, f, gotFile, + "expected cmd.Stdout to be the user-provided *os.File (fd %d), "+ + "got a different *os.File (fd %d). The fd-pass fast path is broken; "+ + "sendfile/zero-copy will not apply.", + f.Fd(), gotFile.Fd(), + ) + }) + } +} + +// TestCommandStageStdoutFastPathThroughPipeline is the same assertion +// but driven end-to-end through `Pipeline.Start()`, so it also +// exercises the `Pipeline.stdout` plumbing that hands the writer to +// the last stage. +func TestCommandStageStdoutFastPathThroughPipeline(t *testing.T) { + cases := []struct { + name string + option func(*os.File) Option + }{ + { + name: "WithStdoutCloser(*os.File)", + option: func(f *os.File) Option { return WithStdoutCloser(f) }, + }, + { + name: "WithStdout(*os.File)", + option: func(f *os.File) Option { return WithStdout(f) }, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + f, err := os.CreateTemp(t.TempDir(), "stdout") + require.NoError(t, err) + t.Cleanup(func() { _ = f.Close() }) + + cmd := exec.Command("true") + s := CommandStage("true", cmd).(*commandStage) + + p := New(tc.option(f)) + p.Add(s) + require.NoError(t, p.Start(ctx)) + stdoutAfterStart := s.cmd.Stdout + t.Cleanup(func() { _ = p.Wait() }) + + gotFile, ok := stdoutAfterStart.(*os.File) + require.Truef(t, ok, "expected cmd.Stdout to be *os.File, got %T", stdoutAfterStart) + assert.Samef( + t, f, gotFile, + "expected cmd.Stdout to be the user-provided *os.File (fd %d), "+ + "got a different *os.File (fd %d). The fd-pass fast path is broken; "+ + "sendfile/zero-copy will not apply.", + f.Fd(), gotFile.Fd(), + ) + }) + } +} diff --git a/pipe/command_test.go b/pipe/command_test.go index ca5a8c0..531f11f 100644 --- a/pipe/command_test.go +++ b/pipe/command_test.go @@ -78,7 +78,8 @@ func TestCopyEnvWithOverride(t *testing.T) { for _, ex := range examples { t.Run(ex.label, func(t *testing.T) { assert.ElementsMatch(t, ex.expectedResult, - copyEnvWithOverrides(ex.env, ex.overrides)) + copyEnvWithOverrides(ex.env, ex.overrides), + ) }) } } diff --git a/pipe/copy_pool.go b/pipe/copy_pool.go new file mode 100644 index 0000000..69f8a21 --- /dev/null +++ b/pipe/copy_pool.go @@ -0,0 +1,40 @@ +package pipe + +import ( + "io" + "sync" +) + +// copyBufPool reuses 32KB buffers across `io.CopyBuffer` calls, +// avoiding a fresh heap allocation per copy. This matters in +// high-throughput pipelines where many command stages run +// concurrently and stdout is not an `*os.File` that can be passed +// directly through `exec.Cmd`. +var copyBufPool = sync.Pool{ + New: func() any { + b := make([]byte, 32*1024) + return &b + }, +} + +// readerOnly wraps an `io.Reader`, hiding any other interfaces (such +// as `io.WriterTo`) so that `io.CopyBuffer` is forced to use the +// provided buffer. Without this, `*os.File`'s `WriterTo` (added in +// Go 1.26) causes `CopyBuffer` to call `File.WriteTo`, which can +// fall back to `io.Copy` with a fresh allocation, bypassing the pool +// entirely. +type readerOnly struct{ io.Reader } + +// pooledCopy copies from `src` to `dst`. If `dst` implements +// `io.ReaderFrom` (e.g. `*net.TCPConn`, `*os.File`), it delegates to +// `ReadFrom` so platform fast paths like splice can be used. +// Otherwise it falls back to `io.CopyBuffer` with a pooled 32KB +// buffer. +func pooledCopy(dst io.Writer, src io.Reader) (int64, error) { + if rf, ok := dst.(io.ReaderFrom); ok { + return rf.ReadFrom(src) + } + bp := copyBufPool.Get().(*[]byte) + defer copyBufPool.Put(bp) + return io.CopyBuffer(dst, readerOnly{src}, *bp) +} diff --git a/pipe/env_stage.go b/pipe/env_stage.go new file mode 100644 index 0000000..3db2296 --- /dev/null +++ b/pipe/env_stage.go @@ -0,0 +1,50 @@ +package pipe + +import "context" + +// WithExtraEnv returns a Stage that adds env to the environment seen by inner. +func WithExtraEnv(inner Stage, env []EnvVar) Stage { + stage := &stageWithExtraEnv{ + inner: inner, + env: env, + } + if processProvider, ok := inner.(processProvider); ok { + return &processStageWithExtraEnv{ + stageWithExtraEnv: stage, + processProvider: processProvider, + } + } + return stage +} + +type processStageWithExtraEnv struct { + processProvider + *stageWithExtraEnv +} + +type stageWithExtraEnv struct { + inner Stage + env []EnvVar +} + +func (s *stageWithExtraEnv) Name() string { + return s.inner.Name() + " (with extra env vars)" +} + +func (s *stageWithExtraEnv) Requirements() StageRequirements { + return s.inner.Requirements() +} + +func (s *stageWithExtraEnv) Start( + ctx context.Context, opts StageOptions, + stdin *InputStream, stdout *OutputStream, +) error { + opts.Vars = append(opts.Vars[:len(opts.Vars):len(opts.Vars)], func(_ context.Context, vars []EnvVar) []EnvVar { + return append(vars, s.env...) + }) + return s.inner.Start(ctx, opts, stdin, stdout) +} + +func (s *stageWithExtraEnv) Wait() error { + return s.inner.Wait() +} diff --git a/pipe/env_stage_test.go b/pipe/env_stage_test.go new file mode 100644 index 0000000..e33c6b7 --- /dev/null +++ b/pipe/env_stage_test.go @@ -0,0 +1,196 @@ +package pipe + +import ( + "bytes" + "context" + "io" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func collectEnvVars(ctx context.Context, env Env) []EnvVar { + var vars []EnvVar + for _, fn := range env.Vars { + vars = fn(ctx, vars) + } + return vars +} + +func lastEnvValue(ctx context.Context, env Env, key string) (string, bool) { + var ( + value string + ok bool + ) + for _, envVar := range collectEnvVars(ctx, env) { + if envVar.Key == key { + value = envVar.Value + ok = true + } + } + return value, ok +} + +func TestWithExtraEnvAddsStageLocalVars(t *testing.T) { + t.Parallel() + + ctx := context.Background() + var firstStageVars []EnvVar + var secondStageVars []EnvVar + + p := New(WithEnvVar("PIPELINE", "present")) + p.Add( + WithExtraEnv( + Function("first", func(ctx context.Context, env Env, _ io.Reader, _ io.Writer) error { + firstStageVars = collectEnvVars(ctx, env) + return nil + }), + []EnvVar{{Key: "STAGE", Value: "first"}}, + ), + Function("second", func(ctx context.Context, env Env, _ io.Reader, _ io.Writer) error { + secondStageVars = collectEnvVars(ctx, env) + return nil + }), + ) + + require.NoError(t, p.Run(ctx)) + assert.Equal(t, []EnvVar{{Key: "PIPELINE", Value: "present"}, {Key: "STAGE", Value: "first"}}, firstStageVars) + assert.Equal(t, []EnvVar{{Key: "PIPELINE", Value: "present"}}, secondStageVars) +} + +func TestWithExtraEnvStageLocalVarsOverridePipelineVars(t *testing.T) { + t.Parallel() + + ctx := context.Background() + var firstStageValue string + var secondStageValue string + + p := New(WithEnvVar("STAGE", "pipeline")) + p.Add( + WithExtraEnv( + Function("first", func(ctx context.Context, env Env, _ io.Reader, _ io.Writer) error { + var ok bool + firstStageValue, ok = lastEnvValue(ctx, env, "STAGE") + require.True(t, ok) + return nil + }), + []EnvVar{{Key: "STAGE", Value: "stage-local"}}, + ), + Function("second", func(ctx context.Context, env Env, _ io.Reader, _ io.Writer) error { + var ok bool + secondStageValue, ok = lastEnvValue(ctx, env, "STAGE") + require.True(t, ok) + return nil + }), + ) + + require.NoError(t, p.Run(ctx)) + assert.Equal(t, "stage-local", firstStageValue) + assert.Equal(t, "pipeline", secondStageValue) +} + +func TestWithExtraEnvDoesNotShareVarsBackingArray(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + allowFirstStage := make(chan struct{}) + var firstStageVars []EnvVar + var secondStageVars []EnvVar + + baseVars := make([]AppendVars, 0, 4) + for _, env := range []EnvVar{ + {Key: "PIPELINE1", Value: "present"}, + {Key: "PIPELINE2", Value: "present"}, + {Key: "PIPELINE3", Value: "present"}, + } { + baseVars = append(baseVars, func(_ context.Context, vars []EnvVar) []EnvVar { + return append(vars, env) + }) + } + + p := New(func(p *Pipeline) { + p.env.Vars = baseVars + }) + p.Add( + WithExtraEnv( + Function("first", func(ctx context.Context, env Env, _ io.Reader, _ io.Writer) error { + select { + case <-allowFirstStage: + case <-ctx.Done(): + return ctx.Err() + } + firstStageVars = collectEnvVars(ctx, env) + return nil + }), + []EnvVar{{Key: "STAGE", Value: "first"}}, + ), + WithExtraEnv( + Function("second", func(ctx context.Context, env Env, _ io.Reader, _ io.Writer) error { + secondStageVars = collectEnvVars(ctx, env) + close(allowFirstStage) + return nil + }), + []EnvVar{{Key: "STAGE", Value: "second"}}, + ), + ) + + require.NoError(t, p.Run(ctx)) + + wantBase := []EnvVar{ + {Key: "PIPELINE1", Value: "present"}, + {Key: "PIPELINE2", Value: "present"}, + {Key: "PIPELINE3", Value: "present"}, + } + assert.Equal(t, append(append([]EnvVar(nil), wantBase...), EnvVar{Key: "STAGE", Value: "first"}), firstStageVars) + assert.Equal(t, append(append([]EnvVar(nil), wantBase...), EnvVar{Key: "STAGE", Value: "second"}), secondStageVars) +} + +func TestWithExtraEnvAddsCommandEnv(t *testing.T) { + t.Parallel() + + ctx := context.Background() + stdout := &bytes.Buffer{} + + p := New(WithStdout(stdout)) + p.Add(WithExtraEnv( + Command("sh", "-c", "printf %s \"$STAGE\""), + []EnvVar{{Key: "STAGE", Value: "command"}}, + )) + + require.NoError(t, p.Run(ctx)) + assert.Equal(t, "command", stdout.String()) +} + +func TestWithExtraEnvPreservesProcessHooks(t *testing.T) { + t.Parallel() + + stage := WithExtraEnv(Command("true"), nil) + assert.Implements(t, (*processProvider)(nil), stage) +} + +func TestWithExtraEnvDoesNotAddProcessHooks(t *testing.T) { + t.Parallel() + + inner := Function("inner", func(context.Context, Env, io.Reader, io.Writer) error { + return nil + }) + + stage := WithExtraEnv(inner, nil) + assert.NotImplements(t, (*processProvider)(nil), stage) +} + +func TestWithExtraEnvPreservesStageMetadata(t *testing.T) { + t.Parallel() + + inner := Function("inner", func(context.Context, Env, io.Reader, io.Writer) error { + return nil + }, ForbidStdin(), ForbidStdout()) + + stage := WithExtraEnv(inner, nil) + assert.Equal(t, "inner (with extra env vars)", stage.Name()) + assert.Equal(t, inner.Requirements(), stage.Requirements()) +} diff --git a/pipe/filter-error.go b/pipe/filter-error.go index 654796a..9bdee27 100644 --- a/pipe/filter-error.go +++ b/pipe/filter-error.go @@ -48,6 +48,12 @@ type ErrorMatcher func(err error) bool // the functions from the standard library that has the same signature // (e.g., `os.IsTimeout`), or some combination of these (e.g., // `AnyError(IsSIGPIPE, os.IsTimeout)`). +// +// `IgnoreError` only suppresses the error returned by the wrapped +// stage. If a producer ignores pipe errors because a later stage can +// stop reading early, the producer is still responsible for keeping any +// producer-owned state, metrics, cursors, or other side effects +// consistent before returning the ignored error. func IgnoreError(s Stage, em ErrorMatcher) Stage { return FilterError(s, func(err error) error { @@ -128,7 +134,11 @@ var ( // IsPipeError is an `ErrorMatcher` that matches a few different // errors that typically result if a stage writes to a subsequent - // stage that has stopped reading from its stdin. Use like + // stage that has stopped reading from its stdin. This is commonly + // useful with `IgnoreError` for stateless producer stages whose only + // job is writing output. Stateful producers should continue any + // producer-owned state updates needed for consistency before + // returning the pipe error for `IgnoreError` to suppress. Use like // // p.Add(IgnoreError(someStage, IsPipeError)) IsPipeError = AnyError(IsSIGPIPE, IsEPIPE, IsErrClosedPipe) diff --git a/pipe/function.go b/pipe/function.go index e8d9522..a80947f 100644 --- a/pipe/function.go +++ b/pipe/function.go @@ -4,29 +4,65 @@ import ( "context" "fmt" "io" + "strings" ) // StageFunc is a function that can be used to power a `goStage`. It // should read its input from `stdin` and write its output to -// `stdout`. `stdin` and `stdout` will be closed automatically (if -// necessary) once the function returns. +// `stdout`. The Function stage closes `stdin` and `stdout` after the +// function returns only when the pipeline gave the stage ownership of +// those streams; StageFunc implementations should not close them +// directly. // // Neither `stdin` nor `stdout` are necessarily buffered. If the // `StageFunc` requires buffering, it needs to arrange that itself. // +// A later stage can stop reading before this function has written all +// of its output. In that case, writes to `stdout` can fail with an +// error matched by `IsPipeError`. If the function only writes output +// and is otherwise stateless, callers can usually wrap the stage with +// `IgnoreError(stage, IsPipeError)`. If the function also updates +// producer-owned state, metrics, cursors, or other side effects that +// depend on how much output was produced, it should bring those side +// effects to a consistent point before returning the write error. +// // A `StageFunc` is run in a separate goroutine, so it must be careful // to synchronize any data access aside from reading and writing. type StageFunc func(ctx context.Context, env Env, stdin io.Reader, stdout io.Writer) error +// FunctionOption configures a Function stage. +type FunctionOption func(*goStage) + +// ForbidStdin returns a FunctionOption declaring that the stage must not be +// connected to stdin. +func ForbidStdin() FunctionOption { + return func(s *goStage) { + s.requirements.Stdin = StreamForbidden + } +} + +// ForbidStdout returns a FunctionOption declaring that the stage must not be +// connected to stdout. +func ForbidStdout() FunctionOption { + return func(s *goStage) { + s.requirements.Stdout = StreamForbidden + } +} + // Function returns a pipeline `Stage` that will run a `StageFunc` in // a separate goroutine to process the data. See `StageFunc` for more // information. -func Function(name string, f StageFunc) Stage { - return &goStage{ - name: name, - f: f, - done: make(chan struct{}), +func Function(name string, f StageFunc, opts ...FunctionOption) Stage { + stage := &goStage{ + name: name, + f: f, + done: make(chan struct{}), + requirements: StageRequirements{}, + } + for _, opt := range opts { + opt(stage) } + return stage } // goStage is a `Stage` that does its work by running an arbitrary @@ -35,54 +71,58 @@ type goStage struct { name string f StageFunc done chan struct{} + requirements StageRequirements err error - panicHandler StagePanicHandler } +var _ Stage = (*goStage)(nil) + func (s *goStage) Name() string { return s.name } -func (s *goStage) SetPanicHandler(ph StagePanicHandler) { - s.panicHandler = ph +func (s *goStage) Requirements() StageRequirements { + return s.requirements } -func (s *goStage) Start(ctx context.Context, env Env, stdin io.ReadCloser) (io.ReadCloser, error) { - r, w := io.Pipe() +func (s *goStage) Start( + ctx context.Context, opts StageOptions, + stdin *InputStream, stdout *OutputStream, +) error { + r := stdin.Reader() + if r == nil { + // treat nil as empty input. + r = strings.NewReader("") + } + + w := stdout.Writer() + if w == nil { + // treat nil output as /dev/null + w = io.Discard + } go func() { defer func() { - // Cleanup resources on exit - if err := w.Close(); err != nil && s.err == nil { - s.err = fmt.Errorf("error closing output pipe for stage %q: %w", s.Name(), err) - } - if stdin != nil { - if err := stdin.Close(); err != nil && s.err == nil { - s.err = fmt.Errorf("error closing stdin for stage %q: %w", s.Name(), err) + if opts.PanicHandler != nil { + if p := recover(); p != nil { + s.err = opts.PanicHandler(p) } } + if err := stdout.Close(); err != nil && s.err == nil { + s.err = fmt.Errorf("error closing stdout for stage %q: %w", s.Name(), err) + } + if err := stdin.Close(); err != nil && s.err == nil { + s.err = fmt.Errorf("error closing stdin for stage %q: %w", s.Name(), err) + } close(s.done) }() - - defer s.recoverPanic() - - s.err = s.f(ctx, env, stdin, w) + s.err = s.f(ctx, opts.Env, r, w) }() - return r, nil + return nil } func (s *goStage) Wait() error { <-s.done return s.err } - -func (s *goStage) recoverPanic() { - if s.panicHandler == nil { - return - } - - if p := recover(); p != nil { - s.err = s.panicHandler(p) - } -} diff --git a/pipe/function_panic_test.go b/pipe/function_panic_test.go new file mode 100644 index 0000000..db6126e --- /dev/null +++ b/pipe/function_panic_test.go @@ -0,0 +1,64 @@ +package pipe_test + +import ( + "context" + "io" + "os" + "os/exec" + "testing" + "time" + + "github.com/github/go-pipe/v2/pipe" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const panicChildEnv = "GO_PIPE_FUNCTION_PANIC_CHILD" +const panicSentinel = "function-panic-sentinel" + +// TestFunctionPanicWithoutHandlerPropagates verifies that when a +// `Function` stage panics and no panic handler is installed, the panic +// propagates (crashing the process) rather than being silently +// swallowed and reported as a successful run. Because a propagating +// panic would crash the test binary itself, the actual pipeline is run +// in a re-exec'd subprocess and this test asserts on its outcome. +func TestFunctionPanicWithoutHandlerPropagates(t *testing.T) { + if os.Getenv(panicChildEnv) == "1" { + runPanicChild() + return + } + + cmd := exec.Command(os.Args[0], "-test.run=^TestFunctionPanicWithoutHandlerPropagates$", "-test.v") //nolint:gosec // re-exec of this test binary with constant arguments. + cmd.Env = append(os.Environ(), panicChildEnv+"=1") + out, err := cmd.CombinedOutput() + output := string(out) + + require.Errorf(t, err, "expected subprocess to crash from a propagated panic, but it exited 0\noutput:\n%s", output) + assert.NotContains(t, output, "SURVIVED", "panic was swallowed: Run returned instead of propagating") + assert.Contains(t, output, "panic:") + assert.Contains(t, output, panicSentinel) +} + +// runPanicChild runs a pipeline whose only stage is a `Function` that panics, +// with no panic handler configured. The panic, being unhandled, should crash +// the process before the sleep elapses; if it is swallowed (the regression), +// Run returns and we print SURVIVED so the parent can detect the failure. +func runPanicChild() { + p := pipe.New(pipe.WithStdout(io.Discard)) + p.Add(pipe.Function("boom", func(_ context.Context, _ pipe.Env, _ io.Reader, _ io.Writer) error { + panic(panicSentinel) + })) + + err := p.Run(context.Background()) + + // reaching this point at all indicates the panic was swallowed. + time.Sleep(2 * time.Second) + os.Stdout.WriteString("SURVIVED: Run returned err=") + if err != nil { + os.Stdout.WriteString(err.Error()) + } else { + os.Stdout.WriteString("") + } + os.Stdout.WriteString("\n") + os.Exit(0) +} diff --git a/pipe/iocopier.go b/pipe/iocopier.go deleted file mode 100644 index 78a9143..0000000 --- a/pipe/iocopier.go +++ /dev/null @@ -1,62 +0,0 @@ -package pipe - -import ( - "context" - "errors" - "io" - "os" -) - -// ioCopier is a stage that copies its stdin to a specified -// `io.Writer`. It generates no stdout itself. -type ioCopier struct { - w io.WriteCloser - done chan struct{} - err error -} - -func newIOCopier(w io.WriteCloser) *ioCopier { - return &ioCopier{ - w: w, - done: make(chan struct{}), - } -} - -func (s *ioCopier) Name() string { - return "ioCopier" -} - -// This method always returns `nil, nil`. -func (s *ioCopier) Start(_ context.Context, _ Env, r io.ReadCloser) (io.ReadCloser, error) { - go func() { - _, err := io.Copy(s.w, r) - // We don't consider `ErrClosed` an error (FIXME: is this - // correct?): - if err != nil && !errors.Is(err, os.ErrClosed) { - s.err = err - } - if err := r.Close(); err != nil && s.err == nil { - s.err = err - } - if err := s.w.Close(); err != nil && s.err == nil { - s.err = err - } - close(s.done) - }() - - // FIXME: if `s.w.Write()` is blocking (e.g., because there is a - // downstream process that is not reading from the other side), - // there's no way to terminate the copy when the context expires. - // This is not too bad, because the `io.Copy()` call will exit by - // itself when its input is closed. - // - // We could, however, be smarter about exiting more quickly if the - // context expires but `s.w.Write()` is not blocking. - - return nil, nil -} - -func (s *ioCopier) Wait() error { - <-s.done - return s.err -} diff --git a/pipe/memorylimit.go b/pipe/memorylimit.go deleted file mode 100644 index 8e91dc1..0000000 --- a/pipe/memorylimit.go +++ /dev/null @@ -1,310 +0,0 @@ -package pipe - -import ( - "context" - "errors" - "fmt" - "io" - "sync" - "time" -) - -const memoryPollInterval = time.Second - -// ErrMemoryLimitExceeded is the error that will be used to kill a process, if -// necessary, from MemoryLimit. -var ErrMemoryLimitExceeded = errors.New("memory limit exceeded") - -// LimitableStage is the superset of Stage that must be implemented by stages -// passed to MemoryLimit and MemoryObserver. -type LimitableStage interface { - Stage - - GetRSSAnon(context.Context) (uint64, error) - Kill(error) -} - -// MemoryLimit watches the memory usage of the stage and stops it if it -// exceeds the given limit. -func MemoryLimit(stage Stage, byteLimit uint64, eventHandler func(e *Event)) Stage { - - limitableStage, ok := stage.(LimitableStage) - if !ok { - eventHandler(&Event{ - Command: stage.Name(), - Msg: "invalid pipe.MemoryLimit usage", - Err: fmt.Errorf("invalid pipe.MemoryLimit usage"), - }) - return stage - } - - return &memoryWatchStage{ - nameSuffix: " with memory limit", - stage: limitableStage, - watch: killAtLimit(byteLimit, eventHandler), - } -} - -func killAtLimit(byteLimit uint64, eventHandler func(e *Event)) memoryWatchFunc { - return func(ctx context.Context, stage LimitableStage) { - var consecutiveErrors int - - t := time.NewTicker(memoryPollInterval) - defer t.Stop() - - for { - select { - case <-ctx.Done(): - return - case <-t.C: - rss, err := stage.GetRSSAnon(ctx) - if err != nil && !errors.Is(err, errProcessInfoMissing) { - consecutiveErrors++ - if consecutiveErrors >= 2 { - eventHandler(&Event{ - Command: stage.Name(), - Msg: "error getting RSS", - Err: err, - }) - } - continue - } - consecutiveErrors = 0 - if rss < byteLimit { - continue - } - eventHandler(&Event{ - Command: stage.Name(), - Msg: "stage exceeded allowed memory use", - Err: fmt.Errorf("stage exceeded allowed memory use"), - Context: map[string]interface{}{ - "limit": byteLimit, - "used": rss, - }, - }) - stage.Kill(ErrMemoryLimitExceeded) - return - } - } - } -} - -// MemoryLimitWithObserver combines MemoryLimit and MemoryObserver into a single -// stage that uses one goroutine instead of two. It watches the memory usage of -// the stage, kills the process if it exceeds byteLimit, and logs peak memory -// usage when the stage exits. -// -// Use this instead of MemoryLimit(MemoryObserver(stage, h), limit, h) to save -// one goroutine per pipeline stage. -func MemoryLimitWithObserver(stage Stage, byteLimit uint64, eventHandler func(e *Event)) Stage { - limitableStage, ok := stage.(LimitableStage) - if !ok { - eventHandler(&Event{ - Command: stage.Name(), - Msg: "invalid pipe.MemoryLimitWithObserver usage", - Err: fmt.Errorf("invalid pipe.MemoryLimitWithObserver usage"), - }) - return stage - } - - return &memoryWatchStage{ - nameSuffix: " with memory limit", - stage: limitableStage, - watch: killAtLimitAndObserve(byteLimit, eventHandler), - } -} - -func killAtLimitAndObserve(byteLimit uint64, eventHandler func(e *Event)) memoryWatchFunc { - return func(ctx context.Context, stage LimitableStage) { - var ( - maxRSS uint64 - samples, errCount, consecutiveErrors int - killed bool - ) - - t := time.NewTicker(memoryPollInterval) - defer t.Stop() - - for { - select { - case <-ctx.Done(): - eventHandler(&Event{ - Command: stage.Name(), - Msg: "peak memory usage", - Context: map[string]interface{}{ - "max_rss_bytes": maxRSS, - "samples": samples, - "errors": errCount, - }, - }) - return - case <-t.C: - if killed { - continue - } - - rss, err := stage.GetRSSAnon(ctx) - if err != nil { - if !errors.Is(err, errProcessInfoMissing) { - errCount++ - consecutiveErrors++ - if consecutiveErrors == 2 { - eventHandler(&Event{ - Command: stage.Name(), - Msg: "error getting RSS", - Err: err, - }) - } - } else { - consecutiveErrors = 0 - } - continue - } - - consecutiveErrors = 0 - samples++ - if rss > maxRSS { - maxRSS = rss - } - - if rss >= byteLimit { - eventHandler(&Event{ - Command: stage.Name(), - Msg: "stage exceeded allowed memory use", - Err: fmt.Errorf("stage exceeded allowed memory use"), - Context: map[string]interface{}{ - "limit": byteLimit, - "used": rss, - }, - }) - stage.Kill(ErrMemoryLimitExceeded) - killed = true - } - } - } - } -} - -// MemoryObserver watches memory use of the stage and logs the maximum -// value when the stage exits. -func MemoryObserver(stage Stage, eventHandler func(e *Event)) Stage { - limitableStage, ok := stage.(LimitableStage) - if !ok { - eventHandler(&Event{ - Command: stage.Name(), - Msg: "invalid pipe.MemoryObserver usage", - Err: fmt.Errorf("invalid pipe.MemoryObserver usage"), - }) - return stage - } - - return &memoryWatchStage{ - stage: limitableStage, - watch: logMaxRSS(eventHandler), - } -} - -func logMaxRSS(eventHandler func(e *Event)) memoryWatchFunc { - - return func(ctx context.Context, stage LimitableStage) { - var ( - maxRSS uint64 - samples, errors, consecutiveErrors int - ) - - t := time.NewTicker(memoryPollInterval) - defer t.Stop() - - for { - select { - case <-ctx.Done(): - eventHandler(&Event{ - Command: stage.Name(), - Msg: "peak memory usage", - Context: map[string]interface{}{ - "max_rss_bytes": maxRSS, - "samples": samples, - "errors": errors, - }, - }) - - return - case <-t.C: - rss, err := stage.GetRSSAnon(ctx) - if err != nil { - errors++ - consecutiveErrors++ - if consecutiveErrors == 2 { - eventHandler(&Event{ - Command: stage.Name(), - Msg: "error getting RSS", - Err: err, - }) - } - // don't log any more errors until we get rss successfully. - continue - } - - consecutiveErrors = 0 - samples++ - if rss > maxRSS { - maxRSS = rss - } - } - } - } -} - -type memoryWatchStage struct { - nameSuffix string - stage LimitableStage - watch memoryWatchFunc - cancel context.CancelFunc - wg sync.WaitGroup -} - -type memoryWatchFunc func(context.Context, LimitableStage) - -var _ LimitableStage = (*memoryWatchStage)(nil) - -func (m *memoryWatchStage) Name() string { - return m.stage.Name() + m.nameSuffix -} - -func (m *memoryWatchStage) Start(ctx context.Context, env Env, stdin io.ReadCloser) (io.ReadCloser, error) { - io, err := m.stage.Start(ctx, env, stdin) - if err != nil { - return nil, err - } - - ctx, cancel := context.WithCancel(ctx) - m.cancel = cancel - m.wg.Add(1) - - go func() { - m.watch(ctx, m.stage) - m.wg.Done() - }() - - return io, nil -} - -func (m *memoryWatchStage) Wait() error { - err := m.stage.Wait() - m.stopWatching() - return err -} - -func (m *memoryWatchStage) GetRSSAnon(ctx context.Context) (uint64, error) { - return m.stage.GetRSSAnon(ctx) -} - -func (m *memoryWatchStage) Kill(err error) { - m.stage.Kill(err) - m.stopWatching() -} - -func (m *memoryWatchStage) stopWatching() { - m.cancel() - m.wg.Wait() -} diff --git a/pipe/memorylimit_test.go b/pipe/memorylimit_test.go deleted file mode 100644 index 37d3edd..0000000 --- a/pipe/memorylimit_test.go +++ /dev/null @@ -1,277 +0,0 @@ -package pipe_test - -import ( - "bytes" - "context" - "fmt" - "io" - "log" - "os" - "strings" - "testing" - "time" - - "github.com/github/go-pipe/pipe" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func LogEventHandler(logger *log.Logger) func(*pipe.Event) { - return func(e *pipe.Event) { - ctx := "" - for k, v := range e.Context { - ctx = fmt.Sprintf("%s,%s=%v", ctx, k, v) - } - - logger.Printf("Command %s failed with message %s and error %s. Context: %s", e.Command, e.Msg, e.Command, ctx) - } -} - -func TestMemoryObserverSimple(t *testing.T) { - t.Parallel() - rss := testMemoryObserver(t, 400, pipe.Command("less")) - require.Greater(t, rss, 400_000_000) -} - -func TestMemoryObserverTreeMem(t *testing.T) { - t.Parallel() - - // create a process tree like: - // /tmp/go-build3166037414/b001/pipe.test -test.paniconexit0 -test.timeout=10m0s -test.count=1 -test.run=Tree -test.v=true - // \_ head -c 3G - // \_ sh -c less; : - // \_ less - // so that MemoryObserver is watching the parent `sh` proc and doesn't detect less's mem usage. - // less should buffer whatever we send it via stdin, giving us some level of control over its - // memory usage. - rss := testMemoryObserver(t, 400, pipe.Command("sh", "-c", "less; :")) - require.Greater(t, rss, 400_000_000) -} - -func testMemoryObserver(t *testing.T, mbs int, stage pipe.Stage) int { - ctx := context.Background() - - stdinReader, stdinWriter := io.Pipe() - - devNull, err := os.OpenFile("/dev/null", os.O_WRONLY, 0) - require.NoError(t, err) - - buf := &bytes.Buffer{} - logger := log.New(buf, "testMemoryObserver", log.Ldate|log.Ltime) - - p := pipe.New(pipe.WithDir("/"), pipe.WithStdin(stdinReader), pipe.WithStdout(devNull)) - p.Add(pipe.MemoryObserver(stage, LogEventHandler(logger))) - require.NoError(t, p.Start(ctx)) - - // Write some nonsense data to less, but don't close stdin until we want it - // to exit. - var bytes [1_000_000]byte - for i := 0; i < mbs; i++ { - n, err := stdinWriter.Write(bytes[:]) - require.NoError(t, err) - require.Equal(t, len(bytes), n) - } - - // MemoryObserver polls every one second, so this should make sure we catch - // at least one. - time.Sleep(2 * time.Second) - - // Close stdin and wait for the pipeline to exit. - require.NoError(t, stdinWriter.Close()) - require.NoError(t, p.Wait()) - - return maxBytes(buf.String()) -} - -func maxBytes(s string) int { - idx := strings.Index(s, "max_rss_bytes=") - if idx < 0 { - return idx - } - var maxRSS int - n, err := fmt.Sscanf(s[idx:], "max_rss_bytes=%d", &maxRSS) - if n != 1 || err != nil { - return -1 - } - return maxRSS -} - -func TestMemoryLimitSimple(t *testing.T) { - t.Parallel() - msg, err := testMemoryLimit(t, 400, 10_000_000, pipe.Command("less")) - assert.Contains(t, msg, "exceeded allowed memory") - assert.Contains(t, msg, "limit=10000000") - require.ErrorContains(t, err, "memory limit exceeded") -} - -func TestMemoryLimitTreeMem(t *testing.T) { - t.Parallel() - msg, err := testMemoryLimit(t, 400, 10_000_000, pipe.Command("sh", "-c", "less; :")) - assert.Contains(t, msg, "exceeded allowed memory") - assert.Contains(t, msg, "limit=10000000") - require.ErrorContains(t, err, "memory limit exceeded") -} - -type closeWrapper struct { - io.Writer - close func() error -} - -func (w closeWrapper) Close() error { - return w.close() -} - -func TestMemoryLimitWithObserverSimple(t *testing.T) { - t.Parallel() - msg, err := testMemoryLimitWithObserver(t, 400, 10_000_000, pipe.Command("less")) - assert.Contains(t, msg, "exceeded allowed memory") - assert.Contains(t, msg, "limit=10000000") - require.ErrorContains(t, err, "memory limit exceeded") -} - -func TestMemoryLimitWithObserverTreeMem(t *testing.T) { - t.Parallel() - msg, err := testMemoryLimitWithObserver(t, 400, 10_000_000, pipe.Command("sh", "-c", "less; :")) - assert.Contains(t, msg, "exceeded allowed memory") - assert.Contains(t, msg, "limit=10000000") - require.ErrorContains(t, err, "memory limit exceeded") -} - -func TestMemoryLimitWithObserverBelowLimit(t *testing.T) { - t.Parallel() - rss := testMemoryLimitWithObserverBelowLimit(t, 400, pipe.Command("less")) - require.Greater(t, rss, 400_000_000) -} - -func TestMemoryLimitWithObserverBelowLimitTreeMem(t *testing.T) { - t.Parallel() - rss := testMemoryLimitWithObserverBelowLimit(t, 400, pipe.Command("sh", "-c", "less; :")) - require.Greater(t, rss, 400_000_000) -} - -func TestMemoryLimitWithObserverLogsPeakOnKill(t *testing.T) { - t.Parallel() - msg, err := testMemoryLimitWithObserver(t, 400, 10_000_000, pipe.Command("less")) - // Verify both limit-exceeded AND peak memory are logged (matching - // the behavior of MemoryLimit(MemoryObserver(...))) - assert.Contains(t, msg, "exceeded allowed memory") - assert.Contains(t, msg, "peak memory usage") - require.ErrorContains(t, err, "memory limit exceeded") -} - -func testMemoryLimitWithObserverBelowLimit(t *testing.T, mbs int, stage pipe.Stage) int { - ctx := context.Background() - - stdinReader, stdinWriter := io.Pipe() - - devNull, err := os.OpenFile("/dev/null", os.O_WRONLY, 0) - require.NoError(t, err) - - buf := &bytes.Buffer{} - logger := log.New(buf, "testMemoryLimitWithObserver", log.Ldate|log.Ltime) - - // Use a high limit so it won't be hit — we want to verify the observer part - p := pipe.New(pipe.WithDir("/"), pipe.WithStdin(stdinReader), pipe.WithStdout(devNull)) - p.Add(pipe.MemoryLimitWithObserver(stage, 100*1024*1024*1024, LogEventHandler(logger))) - require.NoError(t, p.Start(ctx)) - - var bytes [1_000_000]byte - for i := 0; i < mbs; i++ { - n, err := stdinWriter.Write(bytes[:]) - require.NoError(t, err) - require.Equal(t, len(bytes), n) - } - - time.Sleep(2 * time.Second) - - require.NoError(t, stdinWriter.Close()) - require.NoError(t, p.Wait()) - - // Verify that peak memory usage was logged (the observer part) - output := buf.String() - assert.Contains(t, output, "peak memory usage") - - return maxBytes(output) -} - -func testMemoryLimitWithObserver(t *testing.T, mbs int, limit uint64, stage pipe.Stage) (string, error) { - ctx := context.Background() - - stdinReader, stdinWriter := io.Pipe() - - devNull, err := os.OpenFile("/dev/null", os.O_WRONLY, 0) - require.NoError(t, err) - - closedErr := fmt.Errorf("stdout was closed") - stdout := closeWrapper{ - Writer: devNull, - close: func() error { - require.NoError(t, stdinReader.CloseWithError(closedErr)) - return nil - }, - } - - buf := &bytes.Buffer{} - logger := log.New(buf, "testMemoryLimitWithObserver", log.Ldate|log.Ltime) - - p := pipe.New(pipe.WithDir("/"), pipe.WithStdin(stdinReader), pipe.WithStdoutCloser(stdout)) - p.Add(pipe.MemoryLimitWithObserver(stage, limit, LogEventHandler(logger))) - require.NoError(t, p.Start(ctx)) - - var bytes [1_000_000]byte - for i := 0; i < mbs; i++ { - _, err := stdinWriter.Write(bytes[:]) - if err != nil { - require.ErrorIs(t, err, closedErr) - } - } - - require.NoError(t, stdinWriter.Close()) - err = p.Wait() - - return buf.String(), err -} - -func testMemoryLimit(t *testing.T, mbs int, limit uint64, stage pipe.Stage) (string, error) { - ctx := context.Background() - - stdinReader, stdinWriter := io.Pipe() - - devNull, err := os.OpenFile("/dev/null", os.O_WRONLY, 0) - require.NoError(t, err) - - // io.Pipe doesn't know if anything is listening on the other end, so once - // our process is expectedly killed then we'll end up blocked trying to - // write to it. To workaround this, make sure we close the pipe reader when - // we've detected that the process has exited (i.e. when stdout has been - // closed). This will cause our write to immediately fail with this error. - closedErr := fmt.Errorf("stdout was closed") - stdout := closeWrapper{ - Writer: devNull, - close: func() error { - require.NoError(t, stdinReader.CloseWithError(closedErr)) - return nil - }, - } - - buf := &bytes.Buffer{} - logger := log.New(buf, "testMemoryObserver", log.Ldate|log.Ltime) - - p := pipe.New(pipe.WithDir("/"), pipe.WithStdin(stdinReader), pipe.WithStdoutCloser(stdout)) - p.Add(pipe.MemoryLimit(stage, limit, LogEventHandler(logger))) - require.NoError(t, p.Start(ctx)) - - // Write some nonsense data to less. - var bytes [1_000_000]byte - for i := 0; i < mbs; i++ { - _, err := stdinWriter.Write(bytes[:]) - if err != nil { - require.ErrorIs(t, err, closedErr) - } - } - - require.NoError(t, stdinWriter.Close()) - err = p.Wait() - - return buf.String(), err -} diff --git a/pipe/nop_closer.go b/pipe/nop_closer.go deleted file mode 100644 index d435d0a..0000000 --- a/pipe/nop_closer.go +++ /dev/null @@ -1,34 +0,0 @@ -// This file is mostly copied from the Go standard library, which is: -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. -package pipe - -import "io" - -// newNopCloser returns a ReadCloser with a no-op Close method wrapping -// the provided io.Reader r. -// If r implements io.WriterTo, the returned io.ReadCloser will implement io.WriterTo -// by forwarding calls to r. -func newNopCloser(r io.Reader) io.ReadCloser { - if _, ok := r.(io.WriterTo); ok { - return nopCloserWriterTo{r} - } - return nopCloser{r} -} - -type nopCloser struct { - io.Reader -} - -func (nopCloser) Close() error { return nil } - -type nopCloserWriterTo struct { - io.Reader -} - -func (nopCloserWriterTo) Close() error { return nil } - -func (c nopCloserWriterTo) WriteTo(w io.Writer) (n int64, err error) { - return c.Reader.(io.WriterTo).WriteTo(w) -} diff --git a/pipe/nop_closer_test.go b/pipe/nop_closer_test.go new file mode 100644 index 0000000..5a819c9 --- /dev/null +++ b/pipe/nop_closer_test.go @@ -0,0 +1,32 @@ +package pipe + +import ( + "bytes" + "context" + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestGoStageReceivesConcreteWriterToStdin verifies that a Function stage +// receives its stdin as the caller's concrete type, +// so fast-path interfaces such as io.WriterTo survive. This guards +// against regressions where the concrete type is hidden behind a wrapper. +func TestGoStageReceivesConcreteWriterToStdin(t *testing.T) { + src := bytes.NewReader([]byte("hello")) + + var got io.Reader + p := New(WithStdin(src), WithStdout(io.Discard)) + p.Add(Function("capture", func(_ context.Context, _ Env, stdin io.Reader, _ io.Writer) error { + got = stdin + _, err := io.Copy(io.Discard, stdin) + return err + })) + + require.NoError(t, p.Run(context.Background())) + + assert.Same(t, src, got) + assert.Implements(t, (*io.WriterTo)(nil), got) +} diff --git a/pipe/panic.go b/pipe/panic.go deleted file mode 100644 index e0ca600..0000000 --- a/pipe/panic.go +++ /dev/null @@ -1,12 +0,0 @@ -package pipe - -// StagePanicHandlerAware is an interface that Stages can implement to receive -// a panic handler from the pipeline. This is particularly useful for stages -// that execute work in a separate goroutine and need to manage panics occurring -// within that goroutine. -type StagePanicHandlerAware interface { - SetPanicHandler(StagePanicHandler) -} - -// StagePanicHandler is a function that handles panics in a pipeline's stages. -type StagePanicHandler func(p any) error diff --git a/pipe/pipe_matching_test.go b/pipe/pipe_matching_test.go new file mode 100644 index 0000000..89c5ae7 --- /dev/null +++ b/pipe/pipe_matching_test.go @@ -0,0 +1,391 @@ +package pipe_test + +import ( + "context" + "fmt" + "io" + "os" + "testing" + + "github.com/github/go-pipe/v2/pipe" + "github.com/stretchr/testify/assert" +) + +// Tests that `Pipeline.Start()` uses the correct types of pipes in +// various situations. +// +// The type of pipe to use depends on both the source and the consumer +// of the data, including the overall pipeline's stdin and stdout. So +// there are a lot of possibilities to consider. + +type ioExpectation int + +const ( + expectOther ioExpectation = iota + expectFile + + // expectNil means that the stage should be passed a nil stdin / stdout, + // which happens at the beginning / end of a pipeline when no overall + // stdin / stdout is configured. + expectNil +) + +func file(t *testing.T) *os.File { + f, err := os.Open(os.DevNull) + assert.NoError(t, err) + return f +} + +func readCloser() io.ReadCloser { + r, w := io.Pipe() + w.Close() + return r +} + +func writeCloser() io.WriteCloser { + r, w := io.Pipe() + r.Close() + return w +} + +func newPipeSniffingStage( + req pipe.StageRequirements, stdinExpectation, stdoutExpectation ioExpectation, +) *pipeSniffingStage { + return &pipeSniffingStage{ + requirements: req, + expect: pipeExpectations{ + stdin: stdinExpectation, + stdout: stdoutExpectation, + }, + } +} + +func newPipeSniffingFunc( + stdinExpectation, stdoutExpectation ioExpectation, +) *pipeSniffingStage { + return newPipeSniffingStage( + pipe.StageRequirements{ + Stdin: pipe.StreamAcceptAny, + Stdout: pipe.StreamAcceptAny, + }, + stdinExpectation, stdoutExpectation, + ) +} + +func newPipeSniffingCmd( + stdinExpectation, stdoutExpectation ioExpectation, +) *pipeSniffingStage { + return newPipeSniffingStage( + pipe.StageRequirements{ + Stdin: pipe.StreamPreferFile, + Stdout: pipe.StreamPreferFile, + }, + stdinExpectation, stdoutExpectation, + ) +} + +type pipeExpectations struct { + stdin ioExpectation + stdout ioExpectation +} + +type pipeSniffingStage struct { + requirements pipe.StageRequirements + expect pipeExpectations + stdin io.Reader + stdout io.Writer +} + +func (*pipeSniffingStage) Name() string { + return "pipe-sniffer" +} + +func (s *pipeSniffingStage) Requirements() pipe.StageRequirements { + return s.requirements +} + +func (s *pipeSniffingStage) Start( + _ context.Context, _ pipe.StageOptions, + stdin *pipe.InputStream, stdout *pipe.OutputStream, +) error { + s.stdin = stdin.Reader() + _ = stdin.Close() + s.stdout = stdout.Writer() + _ = stdout.Close() + return nil +} + +func (s *pipeSniffingStage) check(t *testing.T, i int) { + t.Helper() + + checkStdinExpectation(t, i, s.expect.stdin, s.stdin) + checkStdoutExpectation(t, i, s.expect.stdout, s.stdout) +} + +func (s *pipeSniffingStage) Wait() error { + return nil +} + +var _ pipe.Stage = (*pipeSniffingStage)(nil) + +func ioTypeString(f any) string { + if f == nil { + return "nil" + } + switch f := f.(type) { + case *os.File: + return "*os.File" + case io.Reader: + return "other" + case io.Writer: + return "other" + default: + return fmt.Sprintf("%T", f) + } +} + +func expectationString(expect ioExpectation) string { + switch expect { + case expectOther: + return "other" + case expectFile: + return "*os.File" + case expectNil: + return "nil" + default: + panic(fmt.Sprintf("invalid ioExpectation: %d", expect)) + } +} + +func checkStdinExpectation(t *testing.T, i int, expect ioExpectation, stdin io.Reader) { + t.Helper() + + ioType := ioTypeString(stdin) + expType := expectationString(expect) + assert.Equalf( + t, expType, ioType, + "stage %d stdin: expected %s, got %s (%T)", i, expType, ioType, stdin, + ) +} + +func checkStdoutExpectation(t *testing.T, i int, expect ioExpectation, stdout io.Writer) { + t.Helper() + + ioType := ioTypeString(stdout) + expType := expectationString(expect) + assert.Equalf( + t, expType, ioType, + "stage %d stdout: expected %s, got %s (%T)", i, expType, ioType, stdout, + ) +} + +type checker interface { + check(t *testing.T, i int) +} + +func TestPipeTypes(t *testing.T) { + ctx := context.Background() + + t.Parallel() + + for _, tc := range []struct { + name string + opts []pipe.Option + stages []pipe.Stage + stdin io.Reader + stdout io.Writer + }{ + { + name: "func", + opts: []pipe.Option{}, + stages: []pipe.Stage{ + newPipeSniffingFunc(expectNil, expectNil), + }, + }, + { + name: "func-file-stdin", + opts: []pipe.Option{ + pipe.WithStdin(file(t)), + }, + stages: []pipe.Stage{ + newPipeSniffingFunc(expectFile, expectNil), + }, + }, + { + name: "func-file-stdout", + opts: []pipe.Option{ + pipe.WithStdout(file(t)), + }, + stages: []pipe.Stage{ + newPipeSniffingFunc(expectNil, expectFile), + }, + }, + { + name: "func-file-stdout-closer", + opts: []pipe.Option{ + pipe.WithStdoutCloser(file(t)), + }, + stages: []pipe.Stage{ + newPipeSniffingFunc(expectNil, expectFile), + }, + }, + { + name: "func-file-stdin-other-stdout-closer-other", + opts: []pipe.Option{ + pipe.WithStdin(readCloser()), + pipe.WithStdoutCloser(writeCloser()), + }, + stages: []pipe.Stage{ + newPipeSniffingFunc(expectOther, expectOther), + }, + }, + { + name: "cmd", + opts: []pipe.Option{}, + stages: []pipe.Stage{ + newPipeSniffingCmd(expectNil, expectNil), + }, + }, + { + name: "cmd-file-stdin", + opts: []pipe.Option{ + pipe.WithStdin(file(t)), + }, + stages: []pipe.Stage{ + newPipeSniffingCmd(expectFile, expectNil), + }, + }, + { + name: "cmd-file-stdout", + opts: []pipe.Option{ + pipe.WithStdout(file(t)), + }, + stages: []pipe.Stage{ + newPipeSniffingCmd(expectNil, expectFile), + }, + }, + { + name: "cmd-file-stdout-closer", + opts: []pipe.Option{ + pipe.WithStdoutCloser(file(t)), + }, + stages: []pipe.Stage{ + newPipeSniffingCmd(expectNil, expectFile), + }, + }, + { + name: "cmd-file-stdin-other-stdout-closer-other", + opts: []pipe.Option{ + pipe.WithStdin(readCloser()), + pipe.WithStdoutCloser(writeCloser()), + }, + stages: []pipe.Stage{ + newPipeSniffingCmd(expectOther, expectOther), + }, + }, + { + name: "func-func", + opts: []pipe.Option{ + pipe.WithStdin(file(t)), + pipe.WithStdoutCloser(writeCloser()), + }, + stages: []pipe.Stage{ + newPipeSniffingFunc(expectFile, expectOther), + newPipeSniffingFunc(expectOther, expectOther), + }, + }, + { + name: "func-cmd", + opts: []pipe.Option{ + pipe.WithStdout(file(t)), + }, + stages: []pipe.Stage{ + newPipeSniffingFunc(expectNil, expectFile), + newPipeSniffingCmd(expectFile, expectFile), + }, + }, + { + name: "cmd-func", + opts: []pipe.Option{ + pipe.WithStdin(readCloser()), + }, + stages: []pipe.Stage{ + newPipeSniffingCmd(expectOther, expectFile), + newPipeSniffingFunc(expectFile, expectNil), + }, + }, + { + name: "cmd-cmd", + opts: []pipe.Option{}, + stages: []pipe.Stage{ + newPipeSniffingCmd(expectNil, expectFile), + newPipeSniffingCmd(expectFile, expectNil), + }, + }, + { + name: "hybrid1", + opts: []pipe.Option{}, + stages: []pipe.Stage{ + newPipeSniffingStage( + pipe.StageRequirements{ + Stdin: pipe.StreamAcceptAny, + Stdout: pipe.StreamAcceptAny, + }, + expectNil, expectOther, + ), + newPipeSniffingStage( + pipe.StageRequirements{ + Stdin: pipe.StreamAcceptAny, + Stdout: pipe.StreamPreferFile, + }, + expectOther, expectFile, + ), + newPipeSniffingStage( + pipe.StageRequirements{ + Stdin: pipe.StreamAcceptAny, + Stdout: pipe.StreamAcceptAny, + }, + expectFile, expectNil, + ), + }, + }, + { + name: "hybrid2", + opts: []pipe.Option{}, + stages: []pipe.Stage{ + newPipeSniffingStage( + pipe.StageRequirements{ + Stdin: pipe.StreamAcceptAny, + Stdout: pipe.StreamAcceptAny, + }, + expectNil, expectFile, + ), + newPipeSniffingStage( + pipe.StageRequirements{ + Stdin: pipe.StreamPreferFile, + Stdout: pipe.StreamAcceptAny, + }, + expectFile, expectOther, + ), + newPipeSniffingStage( + pipe.StageRequirements{ + Stdin: pipe.StreamAcceptAny, + Stdout: pipe.StreamAcceptAny, + }, + expectOther, expectNil, + ), + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + p := pipe.New(tc.opts...) + p.Add(tc.stages...) + assert.NoError(t, p.Run(ctx)) + for i, s := range tc.stages { + s.(checker).check(t, i) + } + }) + } +} diff --git a/pipe/pipeline.go b/pipe/pipeline.go index 8bc3a37..878cdc6 100644 --- a/pipe/pipeline.go +++ b/pipe/pipeline.go @@ -54,8 +54,8 @@ type ContextValuesFunc func(context.Context) []EnvVar type Pipeline struct { env Env - stdin io.Reader - stdout io.WriteCloser + stdin *InputStream + stdout *OutputStream stages []Stage cancel func() @@ -70,14 +70,6 @@ type Pipeline struct { var emptyEventHandler = func(_ *Event) {} -type nopWriteCloser struct { - io.Writer -} - -func (w nopWriteCloser) Close() error { - return nil -} - type NewPipeFn func(opts ...Option) *Pipeline // NewPipeline returns a Pipeline struct with all of the `options` @@ -104,25 +96,31 @@ func WithDir(dir string) Option { } } -// WithStdin assigns stdin to the first command in the pipeline. +// WithStdin assigns stdin to the first command in the pipeline. The +// caller retains ownership of stdin; the pipeline will not close it, +// even if `Start()` returns an error. func WithStdin(stdin io.Reader) Option { return func(p *Pipeline) { - p.stdin = stdin + p.stdin = Input(stdin) } } -// WithStdout assigns stdout to the last command in the pipeline. +// WithStdout assigns stdout to the last command in the pipeline. The +// caller retains ownership of stdout; the pipeline will not close it, +// even if `Start()` returns an error. func WithStdout(stdout io.Writer) Option { return func(p *Pipeline) { - p.stdout = nopWriteCloser{stdout} + p.stdout = Output(stdout) } } // WithStdoutCloser assigns stdout to the last command in the -// pipeline, and closes stdout when it's done. +// pipeline, and closes stdout when the pipeline is done with it. The +// pipeline is responsible for closing stdout even if `Start()` returns +// an error. func WithStdoutCloser(stdout io.WriteCloser) Option { return func(p *Pipeline) { - p.stdout = stdout + p.stdout = ClosingOutput(stdout) } } @@ -186,7 +184,6 @@ func WithEventHandler(handler func(e *Event)) Option { // the client to handle the panic in whatever way they see fit. // // Note: -// - Only the Function stage supports this functionality. // - The client is responsible for deciding whether to recover from the panic or panicking again. // - If a panic handler is not set, the panic will be propagated normally. func WithStagePanicHandler(ph StagePanicHandler) Option { @@ -220,9 +217,20 @@ func (p *Pipeline) AddWithIgnoredError(em ErrorMatcher, stages ...Stage) { } } +func (p *Pipeline) stageOptions() StageOptions { + return StageOptions{Env: p.env, PanicHandler: p.panicHandler} +} + // Start starts the commands in the pipeline. If `Start()` exits // without an error, `Wait()` must also be called, to allow all // resources to be freed. +// +// If `Start()` returns an error, `Wait()` must not be called. Before +// returning an error, `Start()` cancels and waits for any stages that +// were started, closes any inter-stage pipes that the pipeline owns, +// and closes stdout if it was supplied with `WithStdoutCloser()`. +// Streams supplied with `WithStdin()` or `WithStdout()` remain owned by +// the caller and are not closed by the pipeline. func (p *Pipeline) Start(ctx context.Context) error { if p.hasStarted() { panic("attempt to start a pipeline that has already started") @@ -231,93 +239,127 @@ func (p *Pipeline) Start(ctx context.Context) error { atomic.StoreUint32(&p.started, 1) ctx, p.cancel = context.WithCancel(ctx) - var nextStdin io.ReadCloser - if p.stdin != nil { - // We don't want the first stage to actually close this, and - // `p.stdin` is not even necessarily an `io.ReadCloser`. So - // wrap it in a fake `io.ReadCloser` whose `Close()` method - // doesn't do anything. - // - // We could use `io.NopCloser()` for this purpose, but it has - // a subtle problem. If the first stage is a `Command`, then - // it wants to set the `exec.Cmd`'s `Stdin` to an `io.Reader` - // corresponding to `p.stdin`. If `Cmd.Stdin` is an - // `*os.File`, then the file descriptor can be passed to the - // subcommand directly; there is no need for this process to - // create a pipe and copy the data into the input side of the - // pipe. But if `p.stdin` is not an `*os.File`, then this - // optimization is prevented. And even worse, it also has the - // side effect that the goroutine that copies from `Cmd.Stdin` - // into the pipe doesn't terminate until that fd is closed by - // the writing side. - // - // That isn't always what we want. Consider, for example, the - // following snippet, where the subcommand's stdin is set to - // the stdin of the enclosing Go program, but wrapped with - // `io.NopCloser`: - // - // cmd := exec.Command("ls") - // cmd.Stdin = io.NopCloser(os.Stdin) - // cmd.Stdout = os.Stdout - // cmd.Stderr = os.Stderr - // cmd.Run() - // - // In this case, we don't want the Go program to wait for - // `os.Stdin` to close (because `ls` isn't even trying to read - // from its stdin). But it does: `exec.Cmd` doesn't recognize - // that `Cmd.Stdin` is an `*os.File`, so it sets up a pipe and - // copies the data itself, and this goroutine doesn't - // terminate until `cmd.Stdin` (i.e., the Go program's own - // stdin) is closed. But if, for example, the Go program is - // run from an interactive shell session, that might never - // happen, in which case the program will fail to terminate, - // even after `ls` exits. - // - // So instead, in this special case, we wrap `p.stdin` in our - // own `nopCloser`, which behaves like `io.NopCloser`, except - // that `pipe.CommandStage` knows how to unwrap it before - // passing it to `exec.Cmd`. - nextStdin = newNopCloser(p.stdin) + if len(p.stages) == 0 { + if p.stdout == nil { + // No stages and no destination: there is nothing to do + // and nowhere to put `p.stdin` even if it was set. + return nil + } + // No stages but a destination was configured: synthesize an + // identity-copy stage so that `WithStdin()` is drained into + // `WithStdout()`/`WithStdoutCloser()` and the destination + // closer (if any) is invoked. + p.stages = append(p.stages, Function( + "identity", + func(_ context.Context, _ Env, stdin io.Reader, stdout io.Writer) error { + if stdin == nil { + return nil + } + _, err := io.Copy(stdout, stdin) + return err + }, + )) + } + + // We need to decide how to start the stages, especially what + // pipes to use to connect adjacent stages (`os.Pipe()` vs. + // `io.Pipe()`) based on the two stages' requirements. + stageJoiners := make([]stageJoiner, len(p.stages)+1) + + // Arrange for the input of the 0th stage to come from `p.stdin`: + stageJoiners[0].nextStdin = p.stdin + + // Arrange for the output of the last stage to go to `p.stdout`: + stageJoiners[len(p.stages)].prevStdout = p.stdout + + // closePipes closes all of the streams that are currently stored + // in the joiners. This should be called if startup fails. As we + // call `Stage.Start()` and pass that method streams, we clear + // them from the corresponding joiners to avoid closing them + // twice. + closePipes := func() { + for _, sj := range stageJoiners { + _ = sj.closePipe() + } } + // Store the stages in the joiners, and verify that the stages' + // requirements are well-formed: for i, s := range p.stages { - if phs, ok := s.(StagePanicHandlerAware); ok && p.panicHandler != nil { - phs.SetPanicHandler(p.panicHandler) + // Make sure that the stage's requirements are well-formed: + requirements := s.Requirements() + if err := requirements.Stdin.Validate(); err != nil { + closePipes() + return fmt.Errorf("stdin: %w", err) + } + if err := requirements.Stdout.Validate(); err != nil { + closePipes() + return fmt.Errorf("stdout: %w", err) } - var err error - stdout, err := s.Start(ctx, p.env, nextStdin) - if err != nil { - // Close the pipe that the previous stage was writing to. - // That should cause it to exit even if it's not minding - // its context. - if nextStdin != nil { - _ = nextStdin.Close() - } + stageJoiners[i].nextStage = s + stageJoiners[i].nextStageReq = requirements + stageJoiners[i+1].prevStage = s + stageJoiners[i+1].prevStageReq = requirements + } - // Kill and wait for any stages that have been started - // already to finish: - p.cancel() - for _, s := range p.stages[:i] { - _ = s.Wait() - } - p.eventHandler(&Event{ - Command: s.Name(), - Msg: "failed to start pipeline stage", - Err: err, - }) - return fmt.Errorf("starting pipeline stage %q: %w", s.Name(), err) + // Check that each of the stages' requirements are satisfiable: + for i := range stageJoiners { + if err := stageJoiners[i].validate(); err != nil { + closePipes() + return err + } + } + + // Create the "inner" pipes (i.e, all but the first and last + // `stageJoiners`): + for i := 1; i < len(stageJoiners)-1; i++ { + if err := stageJoiners[i].createPipe(); err != nil { + closePipes() + return err } - nextStdin = stdout } - // If the pipeline was configured with a `stdout`, add a synthetic - // stage to copy the last stage's stdout to that writer: - if p.stdout != nil { - c := newIOCopier(p.stdout) - p.stages = append(p.stages, c) - // `ioCopier.Start()` never fails: - _, _ = c.Start(ctx, p.env, nextStdin) + // We're about to start up the stages, one by one. If something + // goes wrong during that process, this function should be called + // to kill any stages that have already been started and to close + // any pipes that have not yet been passed to a stage. `i` is the + // index of the stage that failed to start. If the stage already + // received its streams, it is responsible for closing them. + abort := func(i int, err error) error { + closePipes() + + // Kill and wait for any stages that have been started + // already to finish: + p.cancel() + for _, s := range p.stages[:i] { + _ = s.Wait() + } + p.eventHandler(&Event{ + Command: p.stages[i].Name(), + Msg: "failed to start pipeline stage", + Err: err, + }) + return fmt.Errorf( + "starting pipeline stage %q: %w", p.stages[i].Name(), err, + ) + } + + // Loop over all of the stages, starting them in order. + for i, s := range p.stages { + prevSJ := &stageJoiners[i] + nextSJ := &stageJoiners[i+1] + + err := s.Start(ctx, p.stageOptions(), prevSJ.nextStdin, nextSJ.prevStdout) + + // Even if that stage failed to start, we are no longer + // responsible for closing its streams: + prevSJ.nextStdin = nil + nextSJ.prevStdout = nil + + if err != nil { + return abort(i, err) + } } return nil @@ -325,7 +367,7 @@ func (p *Pipeline) Start(ctx context.Context) error { func (p *Pipeline) Output(ctx context.Context) ([]byte, error) { var buf bytes.Buffer - p.stdout = nopWriteCloser{&buf} + p.stdout = Output(&buf) err := p.Run(ctx) return buf.Bytes(), err } @@ -420,7 +462,9 @@ func (p *Pipeline) Wait() error { return nil } -// Run starts and waits for the commands in the pipeline. +// Run starts and waits for the commands in the pipeline. If startup +// fails, it returns the `Start()` error after `Start()` has performed +// its failure cleanup. func (p *Pipeline) Run(ctx context.Context) error { if err := p.Start(ctx); err != nil { return err diff --git a/pipe/pipeline_test.go b/pipe/pipeline_test.go index bebc931..e956964 100644 --- a/pipe/pipeline_test.go +++ b/pipe/pipeline_test.go @@ -8,7 +8,6 @@ import ( "fmt" "io" "os" - "runtime" "strconv" "strings" "testing" @@ -18,7 +17,7 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/goleak" - "github.com/github/go-pipe/pipe" + "github.com/github/go-pipe/v2/pipe" ) // Check whether this package's test suite leaks any goroutines: @@ -26,31 +25,96 @@ func TestMain(m *testing.M) { goleak.VerifyTestMain(m) } -func TestPipelineFirstStageFailsToStart(t *testing.T) { +func TestPipelineEmpty(t *testing.T) { + t.Parallel() + p := pipe.New() + assert.NoError(t, p.Run(context.Background())) +} + +func TestPipelineEmptyWithStdinAndStdout(t *testing.T) { t.Parallel() ctx := context.Background() + stdout := &bytes.Buffer{} + p := pipe.New( + pipe.WithStdin(strings.NewReader("hello world\n")), + pipe.WithStdout(stdout), + ) + if assert.NoError(t, p.Run(ctx)) { + assert.Equal(t, "hello world\n", stdout.String()) + } +} - dir := t.TempDir() +func TestPipelineEmptyOutput(t *testing.T) { + t.Parallel() + ctx := context.Background() + p := pipe.New(pipe.WithStdin(strings.NewReader("hello world\n"))) + out, err := p.Output(ctx) + if assert.NoError(t, err) { + assert.Equal(t, "hello world\n", string(out)) + } +} + +func TestPipelineEmptyWithStdoutCloser(t *testing.T) { + t.Parallel() + ctx := context.Background() + stdout := &closeTrackingWriter{} + p := pipe.New( + pipe.WithStdin(strings.NewReader("hello world\n")), + pipe.WithStdoutCloser(stdout), + ) + if assert.NoError(t, p.Run(ctx)) { + assert.Equal(t, "hello world\n", stdout.buf.String()) + assert.True(t, stdout.closed, "WithStdoutCloser destination should be closed") + } +} + +type closeTrackingWriter struct { + buf bytes.Buffer + closed bool +} + +func (w *closeTrackingWriter) Write(p []byte) (int, error) { return w.buf.Write(p) } +func (w *closeTrackingWriter) Close() error { + w.closed = true + return nil +} +func TestPipelineFirstStageFailsToStart(t *testing.T) { + t.Parallel() + ctx := context.Background() startErr := errors.New("foo") + stdout := &closeTrackingWriter{} - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New(pipe.WithStdoutCloser(stdout)) p.Add( ErrorStartingStage{startErr}, ErrorStartingStage{errors.New("this error should never happen")}, ) assert.ErrorIs(t, p.Run(ctx), startErr) + assert.True(t, stdout.closed, "WithStdoutCloser destination should be closed") } -func TestPipelineSecondStageFailsToStart(t *testing.T) { +func TestPipelineFirstStageFailsToStartClosesStdoutCloser(t *testing.T) { t.Parallel() ctx := context.Background() + startErr := errors.New("foo") + stdout := &closeTrackingWriter{} - dir := t.TempDir() + p := pipe.New(pipe.WithStdoutCloser(stdout)) + p.Add( + ErrorStartingStage{startErr}, + pipe.Command("this-stage-should-not-start"), + ) + assert.ErrorIs(t, p.Run(ctx), startErr) + assert.True(t, stdout.closed, "WithStdoutCloser destination should be closed") +} +func TestPipelineSecondStageFailsToStart(t *testing.T) { + t.Parallel() + ctx := context.Background() startErr := errors.New("foo") - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( seqFunction(20000), ErrorStartingStage{startErr}, @@ -58,13 +122,26 @@ func TestPipelineSecondStageFailsToStart(t *testing.T) { assert.ErrorIs(t, p.Run(ctx), startErr) } -func TestPipelineSingleCommandOutput(t *testing.T) { +func TestPipelineMiddleStageFailsToStartClosesUnstartedStdoutCloser(t *testing.T) { t.Parallel() ctx := context.Background() + startErr := errors.New("foo") + stdout := &closeTrackingWriter{} - dir := t.TempDir() + p := pipe.New(pipe.WithStdoutCloser(stdout)) + p.Add( + seqFunction(20000), + ErrorStartingStage{startErr}, + ErrorStartingStage{errors.New("this error should never happen")}, + ) + assert.ErrorIs(t, p.Run(ctx), startErr) + assert.True(t, stdout.closed, "WithStdoutCloser destination should be closed") +} - p := pipe.New(pipe.WithDir(dir)) +func TestPipelineSingleCommandOutput(t *testing.T) { + t.Parallel() + ctx := context.Background() + p := pipe.New() p.Add(pipe.Command("echo", "hello world")) out, err := p.Output(ctx) if assert.NoError(t, err) { @@ -75,19 +152,16 @@ func TestPipelineSingleCommandOutput(t *testing.T) { func TestPipelineSingleCommandWithStdout(t *testing.T) { t.Parallel() ctx := context.Background() - - dir := t.TempDir() - stdout := &bytes.Buffer{} - p := pipe.New(pipe.WithDir(dir), pipe.WithStdout(stdout)) + p := pipe.New(pipe.WithStdout(stdout)) p.Add(pipe.Command("echo", "hello world")) if assert.NoError(t, p.Run(ctx)) { assert.Equal(t, "hello world\n", stdout.String()) } } -func TestPipelineStdinFileThatIsNeverClosed(t *testing.T) { +func TestPipelineStdinOSPipeThatIsNeverClosed(t *testing.T) { t.Parallel() // Make sure that the subprocess terminates on its own, as opposed @@ -105,7 +179,10 @@ func TestPipelineStdinFileThatIsNeverClosed(t *testing.T) { var stdout bytes.Buffer - p := pipe.New(pipe.WithStdin(r), pipe.WithStdout(&stdout)) + p := pipe.New( + pipe.WithStdin(r), + pipe.WithStdout(&stdout), + ) // Note that this command doesn't read from its stdin, so it will // terminate regardless of whether `w` gets closed: p.Add(pipe.Command("true")) @@ -115,7 +192,7 @@ func TestPipelineStdinFileThatIsNeverClosed(t *testing.T) { assert.NoError(t, p.Run(ctx)) } -func TestPipelineStdinThatIsNeverClosed(t *testing.T) { +func TestPipelineIOPipeStdinThatIsNeverClosed(t *testing.T) { t.Skip("test not run because it currently deadlocks") t.Parallel() @@ -158,10 +235,7 @@ func TestPipelineStdinThatIsNeverClosed(t *testing.T) { func TestNontrivialPipeline(t *testing.T) { t.Parallel() ctx := context.Background() - - dir := t.TempDir() - - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Command("echo", "hello world"), pipe.Command("sed", "s/hello/goodbye/"), @@ -172,7 +246,33 @@ func TestNontrivialPipeline(t *testing.T) { } } -func TestPipelineReadFromSlowly(t *testing.T) { +func TestOSPipePipelineReadFromSlowly(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + r, w, err := os.Pipe() + require.NoError(t, err) + + var buf []byte + readErr := make(chan error, 1) + + go func() { + time.Sleep(200 * time.Millisecond) + var err error + buf, err = io.ReadAll(r) + readErr <- err + }() + + p := pipe.New(pipe.WithStdoutCloser(w)) + p.Add(pipe.Command("echo", "hello world")) + assert.NoError(t, p.Run(ctx)) + + assert.NoError(t, <-readErr) + assert.Equal(t, "hello world\n", string(buf)) +} + +func TestIOPipePipelineReadFromSlowly(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() @@ -203,16 +303,9 @@ func TestPipelineReadFromSlowly(t *testing.T) { } func TestPipelineReadFromSlowly2(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("FIXME: test skipped on Windows: 'seq' unavailable") - } - t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - - dir := t.TempDir() - r, w := io.Pipe() var buf []byte @@ -236,7 +329,7 @@ func TestPipelineReadFromSlowly2(t *testing.T) { } }() - p := pipe.New(pipe.WithDir(dir), pipe.WithStdout(w)) + p := pipe.New(pipe.WithStdout(w)) p.Add(pipe.Command("seq", "100")) assert.NoError(t, p.Run(ctx)) @@ -252,10 +345,7 @@ func TestPipelineReadFromSlowly2(t *testing.T) { func TestPipelineTwoCommandsPiping(t *testing.T) { t.Parallel() ctx := context.Background() - - dir := t.TempDir() - - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add(pipe.Command("echo", "hello world")) assert.Panics(t, func() { p.Add(pipe.Command("")) }) out, err := p.Output(ctx) @@ -282,10 +372,7 @@ func TestPipelineDir(t *testing.T) { func TestPipelineExit(t *testing.T) { t.Parallel() ctx := context.Background() - - dir := t.TempDir() - - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Command("false"), pipe.Command("true"), @@ -311,16 +398,11 @@ func TestPipelineStderr(t *testing.T) { } func TestPipelineInterrupted(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("FIXME: test skipped on Windows: 'sleep' unavailable") - } - t.Parallel() - dir := t.TempDir() stdout := &bytes.Buffer{} - p := pipe.New(pipe.WithDir(dir), pipe.WithStdout(stdout)) + p := pipe.New(pipe.WithStdout(stdout)) p.Add(pipe.Command("sleep", "10")) ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) @@ -334,16 +416,11 @@ func TestPipelineInterrupted(t *testing.T) { } func TestPipelineCanceled(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("FIXME: test skipped on Windows: 'sleep' unavailable") - } - t.Parallel() - dir := t.TempDir() stdout := &bytes.Buffer{} - p := pipe.New(pipe.WithDir(dir), pipe.WithStdout(stdout)) + p := pipe.New(pipe.WithStdout(stdout)) p.Add(pipe.Command("sleep", "10")) ctx, cancel := context.WithCancel(context.Background()) @@ -362,14 +439,9 @@ func TestPipelineCanceled(t *testing.T) { // unread output in this case *does fit* within the OS-level pipe // buffer. func TestLittleEPIPE(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("FIXME: test skipped on Windows: 'sleep' unavailable") - } - t.Parallel() - dir := t.TempDir() - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Command("sh", "-c", "sleep 1; echo foo"), pipe.Command("true"), @@ -386,14 +458,9 @@ func TestLittleEPIPE(t *testing.T) { // amount of unread output in this case *does not fit* within the // OS-level pipe buffer. func TestBigEPIPE(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("FIXME: test skipped on Windows: 'seq' unavailable") - } - t.Parallel() - dir := t.TempDir() - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Command("seq", "100000"), pipe.Command("true"), @@ -410,14 +477,9 @@ func TestBigEPIPE(t *testing.T) { // amount of unread output in this case *does not fit* within the // OS-level pipe buffer. func TestIgnoredSIGPIPE(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("FIXME: test skipped on Windows: 'seq' unavailable") - } - t.Parallel() - dir := t.TempDir() - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.IgnoreError(pipe.Command("seq", "100000"), pipe.IsSIGPIPE), pipe.Command("echo", "foo"), @@ -430,14 +492,76 @@ func TestIgnoredSIGPIPE(t *testing.T) { assert.EqualValues(t, "foo\n", out) } -func TestFunction(t *testing.T) { +func TestGoProducerSeesPipeErrorWhenCommandStopsReading(t *testing.T) { t.Parallel() - ctx := context.Background() - dir := t.TempDir() + p := pipe.New() + p.Add( + pipe.Function( + "write-to-closed-command", + func(_ context.Context, _ pipe.Env, _ io.Reader, stdout io.Writer) error { + w := bufio.NewWriter(stdout) + for i := 0; i < 100000; i++ { + if _, err := fmt.Fprintln(w, i); err != nil { + return err + } + } + return w.Flush() + }, + ), + pipe.Command("true"), + ) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + err := p.Run(ctx) + require.Error(t, err) + assert.True(t, pipe.IsPipeError(err), "expected a pipe error, got %v", err) +} +func TestIgnoredPipeErrorStillAllowsStatefulProducerToFinish(t *testing.T) { + t.Parallel() + + const total = 100000 + processed := 0 + p := pipe.New() + p.Add( + pipe.IgnoreError( + pipe.Function( + "stateful-producer", + func(_ context.Context, _ pipe.Env, _ io.Reader, stdout io.Writer) error { + w := bufio.NewWriter(stdout) + var writeErr error + for i := 0; i < total; i++ { + processed++ + if writeErr == nil { + if _, err := fmt.Fprintln(w, i); err != nil { + writeErr = err + } + } + } + if writeErr == nil { + writeErr = w.Flush() + } + return writeErr + }, + ), + pipe.IsPipeError, + ), + pipe.Command("true"), + ) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + require.NoError(t, p.Run(ctx)) + assert.Equal(t, total, processed) +} + +func TestFunction(t *testing.T) { + t.Parallel() + ctx := context.Background() t.Run("successful function", func(t *testing.T) { - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Print("hello world"), pipe.Function( @@ -463,10 +587,8 @@ func TestFunction(t *testing.T) { t.Run("panic with handler", func(t *testing.T) { p := pipe.New( - pipe.WithDir(dir), pipe.WithStagePanicHandler(func(p any) error { - err := fmt.Errorf("panic handled: %v", p) - return err + return fmt.Errorf("panic handled: %v", p) }), ) p.Add( @@ -483,15 +605,36 @@ func TestFunction(t *testing.T) { assert.ErrorContains(t, err, "panic handled") assert.Empty(t, out) }) + + t.Run("panic with handler through IgnoreError", func(t *testing.T) { + p := pipe.New( + pipe.WithStagePanicHandler(func(p any) error { + return fmt.Errorf("panic handled: %v", p) + }), + ) + p.Add( + pipe.Print("hello world"), + pipe.IgnoreError( + pipe.Function( + "farewell", + func(_ context.Context, _ pipe.Env, _ io.Reader, _ io.Writer) error { + panic("this is a panic") + }, + ), + func(_ error) bool { return false }, + ), + ) + + out, err := p.Output(ctx) + assert.ErrorContains(t, err, "panic handled") + assert.Empty(t, out) + }) } func TestPipelineWithFunction(t *testing.T) { t.Parallel() ctx := context.Background() - - dir := t.TempDir() - - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Command("echo", "-n", "hello world"), pipe.Function( @@ -524,16 +667,53 @@ func (s ErrorStartingStage) Name() string { return "errorStartingStage" } +func (s ErrorStartingStage) Requirements() pipe.StageRequirements { + return pipe.StageRequirements{} +} + func (s ErrorStartingStage) Start( - _ context.Context, _ pipe.Env, _ io.ReadCloser, -) (io.ReadCloser, error) { - return io.NopCloser(&bytes.Buffer{}), s.err + _ context.Context, _ pipe.StageOptions, + stdin *pipe.InputStream, stdout *pipe.OutputStream, +) error { + _ = stdin.Close() + _ = stdout.Close() + return s.err } func (s ErrorStartingStage) Wait() error { return nil } +type requirementStage struct { + name string + requirement pipe.StageRequirements + started *bool +} + +func (s requirementStage) Name() string { + return s.name +} + +func (s requirementStage) Requirements() pipe.StageRequirements { + return s.requirement +} + +func (s requirementStage) Start( + _ context.Context, _ pipe.StageOptions, + stdin *pipe.InputStream, stdout *pipe.OutputStream, +) error { + if s.started != nil { + *s.started = true + } + _ = stdin.Close() + _ = stdout.Close() + return nil +} + +func (s requirementStage) Wait() error { + return nil +} + func seqFunction(n int) pipe.Stage { return pipe.Function( "seq", @@ -552,10 +732,7 @@ func seqFunction(n int) pipe.Stage { func TestPipelineWithLinewiseFunction(t *testing.T) { t.Parallel() ctx := context.Background() - - dir := t.TempDir() - - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() // Print the numbers from 1 to 20 (generated from scratch): p.Add( seqFunction(20), @@ -694,10 +871,7 @@ func TestScannerFinishEarly(t *testing.T) { func TestPrintln(t *testing.T) { t.Parallel() ctx := context.Background() - - dir := t.TempDir() - - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add(pipe.Println("Look Ma, no hands!")) out, err := p.Output(ctx) if assert.NoError(t, err) { @@ -708,10 +882,7 @@ func TestPrintln(t *testing.T) { func TestPrintf(t *testing.T) { t.Parallel() ctx := context.Background() - - dir := t.TempDir() - - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add(pipe.Printf("Strangely recursive: %T", p)) out, err := p.Output(ctx) if assert.NoError(t, err) { @@ -719,6 +890,202 @@ func TestPrintf(t *testing.T) { } } +func TestPrintlnNoOutput(t *testing.T) { + t.Parallel() + ctx := context.Background() + p := pipe.New() + p.Add(pipe.Println("Look Ma, no output!")) + assert.NoError(t, p.Run(ctx)) +} + +func TestPrintlnForbidsStdin(t *testing.T) { + t.Parallel() + ctx := context.Background() + + t.Run("pipeline stdin", func(t *testing.T) { + t.Parallel() + p := pipe.New(pipe.WithStdin(strings.NewReader("ignored"))) + p.Add(pipe.Println("Look Ma, no stdin!")) + require.ErrorContains(t, p.Run(ctx), `stage "println" forbids stdin`) + }) + + t.Run("previous stage", func(t *testing.T) { + t.Parallel() + p := pipe.New() + p.Add( + seqFunction(1), + pipe.Println("Look Ma, no previous stage!"), + ) + require.ErrorContains(t, p.Run(ctx), `stage "println" forbids stdin`) + }) +} + +func TestFunctionOptionsForbidStreams(t *testing.T) { + t.Parallel() + ctx := context.Background() + + t.Run("stdin", func(t *testing.T) { + t.Parallel() + p := pipe.New(pipe.WithStdin(strings.NewReader("ignored"))) + p.Add(pipe.Function( + "source", + func(_ context.Context, _ pipe.Env, _ io.Reader, _ io.Writer) error { + return nil + }, + pipe.ForbidStdin(), + )) + require.ErrorContains(t, p.Run(ctx), `stage "source" forbids stdin`) + }) + + t.Run("stdout", func(t *testing.T) { + t.Parallel() + p := pipe.New(pipe.WithStdout(io.Discard)) + p.Add(pipe.Function( + "sink", + func(_ context.Context, _ pipe.Env, _ io.Reader, _ io.Writer) error { + return nil + }, + pipe.ForbidStdout(), + )) + require.ErrorContains(t, p.Run(ctx), `stage "sink" forbids stdout`) + }) +} + +func TestStreamForbiddenStdin(t *testing.T) { + t.Parallel() + ctx := context.Background() + stage := requirementStage{ + name: "source", + requirement: pipe.StageRequirements{ + Stdin: pipe.StreamForbidden, + }, + } + + t.Run("without stdin", func(t *testing.T) { + t.Parallel() + p := pipe.New() + p.Add(stage) + require.NoError(t, p.Run(ctx)) + }) + + t.Run("with stdin", func(t *testing.T) { + t.Parallel() + p := pipe.New(pipe.WithStdin(strings.NewReader("ignored"))) + p.Add(stage) + require.ErrorContains(t, p.Run(ctx), `stage "source" forbids stdin`) + }) +} + +func TestStreamForbiddenStdout(t *testing.T) { + t.Parallel() + ctx := context.Background() + stage := requirementStage{ + name: "sink", + requirement: pipe.StageRequirements{ + Stdout: pipe.StreamForbidden, + }, + } + + t.Run("without stdout", func(t *testing.T) { + t.Parallel() + p := pipe.New() + p.Add(stage) + require.NoError(t, p.Run(ctx)) + }) + + t.Run("with stdout", func(t *testing.T) { + t.Parallel() + p := pipe.New(pipe.WithStdout(io.Discard)) + p.Add(stage) + require.ErrorContains(t, p.Run(ctx), `stage "sink" forbids stdout`) + }) + + t.Run("with stdout closer", func(t *testing.T) { + t.Parallel() + stdout := &closeTrackingWriter{} + p := pipe.New(pipe.WithStdoutCloser(stdout)) + p.Add(stage) + require.ErrorContains(t, p.Run(ctx), `stage "sink" forbids stdout`) + assert.True(t, stdout.closed, "WithStdoutCloser destination should be closed") + }) +} + +func TestInvalidStreamRequirements(t *testing.T) { + t.Parallel() + ctx := context.Background() + + t.Run("stdin", func(t *testing.T) { + t.Parallel() + stdout := &closeTrackingWriter{} + p := pipe.New(pipe.WithStdoutCloser(stdout)) + p.Add(requirementStage{ + name: "source", + requirement: pipe.StageRequirements{ + Stdin: pipe.StreamRequirement(123), + }, + }) + require.ErrorContains(t, p.Run(ctx), `stdin: invalid stream requirement 123`) + assert.True(t, stdout.closed, "WithStdoutCloser destination should be closed") + }) + + t.Run("stdout", func(t *testing.T) { + t.Parallel() + stdout := &closeTrackingWriter{} + p := pipe.New(pipe.WithStdoutCloser(stdout)) + p.Add(requirementStage{ + name: "sink", + requirement: pipe.StageRequirements{ + Stdout: pipe.StreamRequirement(123), + }, + }) + require.ErrorContains(t, p.Run(ctx), `stdout: invalid stream requirement 123`) + assert.True(t, stdout.closed, "WithStdoutCloser destination should be closed") + }) +} + +func TestStreamForbiddenMiddleStage(t *testing.T) { + t.Parallel() + ctx := context.Background() + var started bool + p := pipe.New() + p.Add( + requirementStage{name: "previous", started: &started}, + requirementStage{ + name: "middle-source", + requirement: pipe.StageRequirements{ + Stdin: pipe.StreamForbidden, + }, + }, + ) + require.ErrorContains(t, p.Run(ctx), `stage "middle-source" forbids stdin`) + assert.False(t, started, "preflight validation should run before starting earlier stages") +} + +func TestInvalidStreamRequirement(t *testing.T) { + t.Parallel() + ctx := context.Background() + p := pipe.New() + p.Add(requirementStage{ + name: "invalid", + requirement: pipe.StageRequirements{ + Stdin: pipe.StreamRequirement(99), + }, + }) + require.ErrorContains(t, p.Run(ctx), `stdin: invalid stream requirement 99`) +} + +func TestFunctionNoInput(t *testing.T) { + t.Parallel() + ctx := context.Background() + p := pipe.New() + p.Add(pipe.Function("read-all", func(_ context.Context, _ pipe.Env, stdin io.Reader, _ io.Writer) error { + n, err := io.Copy(io.Discard, stdin) + assert.Equal(t, int64(0), n) + return err + })) + assert.NoError(t, p.Run(ctx)) +} + func TestErrors(t *testing.T) { t.Parallel() ctx := context.Background() @@ -903,11 +1270,8 @@ func TestErrors(t *testing.T) { func BenchmarkSingleProgram(b *testing.B) { ctx := context.Background() - - dir := b.TempDir() - for i := 0; i < b.N; i++ { - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Command("true"), ) @@ -917,11 +1281,8 @@ func BenchmarkSingleProgram(b *testing.B) { func BenchmarkTenPrograms(b *testing.B) { ctx := context.Background() - - dir := b.TempDir() - for i := 0; i < b.N; i++ { - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Command("echo", "hello world"), pipe.Command("cat"), @@ -943,16 +1304,13 @@ func BenchmarkTenPrograms(b *testing.B) { func BenchmarkTenFunctions(b *testing.B) { ctx := context.Background() - - dir := b.TempDir() - cp := func(_ context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error { _, err := io.Copy(stdout, stdin) return err } for i := 0; i < b.N; i++ { - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Println("hello world"), pipe.Function("copy1", cp), @@ -974,16 +1332,13 @@ func BenchmarkTenFunctions(b *testing.B) { func BenchmarkTenMixedStages(b *testing.B) { ctx := context.Background() - - dir := b.TempDir() - cp := func(_ context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error { _, err := io.Copy(stdout, stdin) return err } for i := 0; i < b.N; i++ { - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Command("echo", "hello world"), pipe.Function("copy1", cp), @@ -1003,6 +1358,97 @@ func BenchmarkTenMixedStages(b *testing.B) { } } +func BenchmarkMoreDataUnbuffered(b *testing.B) { + ctx := context.Background() + + cp := func(_ context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error { + _, err := io.Copy(stdout, stdin) + return err + } + + for i := 0; i < b.N; i++ { + count := 0 + p := pipe.New() + p.Add( + pipe.Function( + "seq", + func(_ context.Context, _ pipe.Env, _ io.Reader, stdout io.Writer) error { + for i := 1; i <= 100000; i++ { + fmt.Fprintln(stdout, i) + } + return nil + }, + ), + pipe.Command("cat"), + pipe.Function("copy2", cp), + pipe.Command("cat"), + pipe.Function("copy3", cp), + pipe.Command("cat"), + pipe.Function("copy4", cp), + pipe.Command("cat"), + pipe.Function("copy5", cp), + pipe.Command("cat"), + pipe.LinewiseFunction( + "count", + func(_ context.Context, _ pipe.Env, _ []byte, _ *bufio.Writer) error { + count++ + return nil + }, + ), + ) + err := p.Run(ctx) + if assert.NoError(b, err) { + assert.EqualValues(b, 100000, count) + } + } +} + +func BenchmarkMoreDataBuffered(b *testing.B) { + ctx := context.Background() + + cp := func(_ context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error { + _, err := io.Copy(stdout, stdin) + return err + } + + for i := 0; i < b.N; i++ { + count := 0 + p := pipe.New() + p.Add( + pipe.Function( + "seq", + func(_ context.Context, _ pipe.Env, _ io.Reader, stdout io.Writer) error { + out := bufio.NewWriter(stdout) + for i := 1; i <= 1000000; i++ { + fmt.Fprintln(out, i) + } + return out.Flush() + }, + ), + pipe.Command("cat"), + pipe.Function("copy2", cp), + pipe.Command("cat"), + pipe.Function("copy3", cp), + pipe.Command("cat"), + pipe.Function("copy4", cp), + pipe.Command("cat"), + pipe.Function("copy5", cp), + pipe.Command("cat"), + pipe.LinewiseFunction( + "count", + func(_ context.Context, _ pipe.Env, _ []byte, _ *bufio.Writer) error { + count++ + return nil + }, + ), + ) + err := p.Run(ctx) + if assert.NoError(b, err) { + assert.EqualValues(b, 1000000, count) + } + } +} + func genErr(err error) pipe.StageFunc { return func(_ context.Context, _ pipe.Env, _ io.Reader, _ io.Writer) error { return err diff --git a/pipe/print.go b/pipe/print.go index 766418d..559b639 100644 --- a/pipe/print.go +++ b/pipe/print.go @@ -13,6 +13,7 @@ func Print(a ...interface{}) Stage { _, err := fmt.Fprint(stdout, a...) return err }, + ForbidStdin(), ) } @@ -23,6 +24,7 @@ func Println(a ...interface{}) Stage { _, err := fmt.Fprintln(stdout, a...) return err }, + ForbidStdin(), ) } @@ -33,5 +35,6 @@ func Printf(format string, a ...interface{}) Stage { _, err := fmt.Fprintf(stdout, format, a...) return err }, + ForbidStdin(), ) } diff --git a/pipe/scanner.go b/pipe/scanner.go index b56b58c..43fd8f4 100644 --- a/pipe/scanner.go +++ b/pipe/scanner.go @@ -56,12 +56,8 @@ func ScannerFunction( return err } } - if err := scanner.Err(); err != nil { - return err - } - - return nil - // `p.AddFunction()` arranges for `stdout` to be closed. + return scanner.Err() + // The Function stage closes `stdout` if it owns it. }, ) } diff --git a/pipe/stage.go b/pipe/stage.go index f3d74d9..0e88d90 100644 --- a/pipe/stage.go +++ b/pipe/stage.go @@ -2,33 +2,90 @@ package pipe import ( "context" - "io" ) -// Stage is an element of a `Pipeline`. +// Stage is an element of a `Pipeline`. It reads from standard input +// and writes to standard output. +// +// # Who closes stdin and stdout? +// +// A `Stage` is responsible for calling `Close()` on the +// `InputStream`/`OutputStream` that represent its stdin and stdout as +// soon as it doesn't need them anymore. That responsibility begins as +// soon as the stage's `Start()` method is called, and applies +// regardless of whether `Start()` returns an error. It must close the +// streams before its `Wait()` method returns. The caller must not +// close the streams after calling `Start()`. +// +// Closing stdin/stdout tells the previous/next stage that this stage +// is done reading/writing data, which can affect their behavior. +// Therefore, it is important for a stage to close each one as soon as +// it is done with it. +// +// From the point of view of the pipeline as a whole, if stdin is +// provided by the user (`WithStdin()`), then we don't want the first +// stage to close it at all. This is arranged by passing a +// non-closing `InputStream` when it starts that stage. For stdout, it +// depends on whether the user supplied it using `WithStdout()` or +// `WithStdoutCloser()`, and in the former case provides the last +// stage with a non-closing `OutputStream`. Calling `Close()` on a +// non-closing stream (or even on a nil stream) is a NOP, so the +// `Stage` can always call `Close()` and doesn't have to worry about +// whether a stdin/stdout stream is non-closing. type Stage interface { // Name returns the name of the stage. Name() string + // Requirements returns this stage's requirements regarding how its + // stdin and stdout pipes should be created. + Requirements() StageRequirements + // Start starts the stage in the background, in the environment - // described by `env`, and using `stdin` as input. (`stdin` should - // be set to `nil` if the stage is to receive no input, which - // might be the case for the first stage in a pipeline.) It - // returns an `io.ReadCloser` from which the stage's output can be - // read (or `nil` if it generates no output, which should only be - // the case for the last stage in a pipeline). It is the stages' - // responsibility to close `stdin` (if it is not nil) when it has - // read all of the input that it needs, and to close the write end - // of its output reader when it is done, as that is generally how - // the subsequent stage knows that it has received all of its - // input and can finish its work, too. + // described by `opts.Env`, using `stdin` to provide its input and + // `stdout` to collect its output. (`stdin.Reader()` or + // `stdout.Writer()` might be `nil` if the stage is to receive no + // input or produce no output, which might be the case for the + // first/last stage in a pipeline.) The stage is responsible for + // calling `stdin.Close()` and `stdout.Close()`, even if `Start()` + // returns an error. See the `Stage` type comment for more + // information about responsibility for closing stdin and stdout. // // If `Start()` returns without an error, `Wait()` must also be - // called, to allow all resources to be freed. - Start(ctx context.Context, env Env, stdin io.ReadCloser) (io.ReadCloser, error) + // called, to allow all resources to be freed. If `Start()` returns + // an error, `Wait()` must not be called. + Start( + ctx context.Context, opts StageOptions, + stdin *InputStream, stdout *OutputStream, + ) error // Wait waits for the stage to be done, either because it has // finished or because it has been killed due to the expiration of // the context passed to `Start()`. Wait() error } + +// StageOptions carries everything (other than `ctx`, `stdin`, and +// `stdout`) that a pipeline passes to `Stage.Start`. +type StageOptions struct { + // Env is the environment (working directory and extra environment + // variables) that the stage should run in. + Env + + // PanicHandler, if non-nil, is invoked to recover a panic that escapes + // user code that a stage runs in a library-spawned goroutine, + // converting it into an error. Stage types that don't run user code in + // a library-spawned goroutine ignore it. + PanicHandler StagePanicHandler +} + +// StagePanicHandler is a function that handles panics in a pipeline's stages. +type StagePanicHandler func(p any) error + +// StageRequirements describes what a Stage needs from the streams +// connected to its stdin and stdout. The zero value is correct for +// stages that are happy with arbitrary io.Reader/io.Writer streams, +// such as Function stages. +type StageRequirements struct { + Stdin StreamRequirement + Stdout StreamRequirement +} diff --git a/pipe/stage_joiner.go b/pipe/stage_joiner.go new file mode 100644 index 0000000..03569cb --- /dev/null +++ b/pipe/stage_joiner.go @@ -0,0 +1,133 @@ +package pipe + +import ( + "errors" + "fmt" + "io" + "os" +) + +// stageJoiner is a helper type that helps join two adjacent stages +// together. stageJoiners[i] tells how to connect stage `i-1` to stage +// `i`. From the point of view of stages, `stageJoiners[i].nextStdin` +// and `stageJoiners[i+1].prevStdout` are the input and output +// streams, respectively, of `stage[i]`. The first and last elements +// of `stageJoiners` manage `p.stdin` and `p.stdout`, respectively. +// Schematically, the data flows through like this: +// +// p.stdin == stageJoiners[0].nextStdin → +// stage[0] → +// stageJoiners[1].prevStdout → stageJoiners[1].nextStdin → +// stage[1] → +// stageJoiners[2].prevStdout → stageJoiners[2].nextStdin → +// stage[2] → +// ... → +// stageJoiners[i].prevStdout → stageJoiners[i].nextStdin → +// stage[i] → +// stageJoiners[i+1].prevStdout → stageJoiners[i+1].nextStdin → +// ... → +// stageJoiners[len(stages)-1].prevStdout → stageJoiners[len(stages)-1].nextStdin → +// stage[len(stages)-1] → +// stageJoiners[len(stages)].prevStdout == p.stdout +// +// In pseudo-Shell notation, the stages are run like this: +// +// stage[0] stageJoiners[1].prevStdout +// stage[1] stageJoiners[2].prevStdout +// stage[2] stageJoiners[3].prevStdout +// ... +// stage[i] stageJoiners[i].prevStdout +// ... +// stage[len(stages)-1] p.stdout +type stageJoiner struct { + // prevStage holds the stage that needs to write to the pipe. + prevStage Stage + + // prevStageReq caches `prevStage.Requirements()` so that it + // doesn't have to be recomputed. It is the zero value if + // `prevStage` is nil. + prevStageReq StageRequirements + + // prevStdout will be used as the stdout of `prevStage`. It is + // usually the "write" end of the `(nextStdin, prevStdout)` pipe + // pair, with the connected pipe ends in the same `stageJoiner` + // instance. + prevStdout *OutputStream + + // nextStage holds the stage that needs to read from the pipe. + nextStage Stage + + // nextStageReq caches `nextStage.Requirements()` so that it + // doesn't have to be recomputed. It is the zero value if + // `nextStage` is nil. + nextStageReq StageRequirements + + // nextStdin will be used as the stdin of `nextStage`. It is + // usually the "read" end of the `(nextStdin, prevStdout)` pipe + // pair. + nextStdin *InputStream +} + +// needFilePipe returns `true` if the pipe that joins the two adjacent +// stages should be an `os.Pipe()` rather than an `io.Pipe()`. +func (sj *stageJoiner) needFilePipe() bool { + return sj.prevStageReq.Stdout == StreamPreferFile || + sj.nextStageReq.Stdin == StreamPreferFile +} + +func (sj *stageJoiner) createPipe() error { + var r io.ReadCloser + var w io.WriteCloser + if sj.needFilePipe() { + var err error + r, w, err = os.Pipe() + if err != nil { + return fmt.Errorf("creating os.Pipe: %w", err) + } + } else { + r, w = io.Pipe() + } + + sj.prevStdout = ClosingOutput(w) + sj.nextStdin = ClosingInput(r) + + return nil +} + +// closePipe closes both ends of the pipe that was allocated by +// `createPipe()`. This should only be called if the corresponding +// stage's `Start()` method was never called (otherwise the stage is +// responsible for closing its stdin and stdout). +func (sj *stageJoiner) closePipe() error { + return errors.Join( + sj.prevStdout.Close(), + sj.nextStdin.Close(), + ) +} + +// validate verifies that the adjacent stages' stream requirements are +// satisfiable, in particular that a stage that forbids its stdin or +// stdout is not connected to anything. +func (sj *stageJoiner) validate() error { + // `prevStage`'s stdout is connected if there is a `nextStage` to + // consume it (in which case an inner pipe will be created) or if + // a stream (`p.stdout`) has already been stored in `prevStdout`. + if sj.prevStage != nil && sj.prevStageReq.Stdout == StreamForbidden && + (sj.nextStage != nil || sj.prevStdout != nil) { + return fmt.Errorf( + "stage %q forbids stdout, but stdout is connected", sj.prevStage.Name(), + ) + } + + // `nextStage`'s stdin is connected if there is a `prevStage` to + // produce it (in which case an inner pipe will be created) or if + // a stream (`p.stdin`) has already been stored in `nextStdin`. + if sj.nextStage != nil && sj.nextStageReq.Stdin == StreamForbidden && + (sj.prevStage != nil || sj.nextStdin != nil) { + return fmt.Errorf( + "stage %q forbids stdin, but stdin is connected", sj.nextStage.Name(), + ) + } + + return nil +} diff --git a/pipe/stream_requirement.go b/pipe/stream_requirement.go new file mode 100644 index 0000000..ddff829 --- /dev/null +++ b/pipe/stream_requirement.go @@ -0,0 +1,38 @@ +package pipe + +import "fmt" + +// StreamRequirement describes a `Stage`'s requirement for its stdin +// or stdout, namely whether it can be anything, whether it should +// preferably be an `*os.File`, or whether it must be `nil`. The zero +// value `StreamAcceptAny` is a valid value that indicates that the +// stage has no particular requirements or preferences for its +// stdin/stdout, such as a typical `Function` stage. +type StreamRequirement int + +const ( + // StreamAcceptAny indicates that the stage hasn't declared what + // kind of stream it requires, maybe even `nil`. + StreamAcceptAny StreamRequirement = iota + + // StreamPreferFile indicates that the stage prefers the + // corresponding stream to be backed by an `*os.File` (a real file + // descriptor), but it can work with any io.Reader/io.Writer. + StreamPreferFile + + // StreamForbidden indicates that the stage requires the + // corresponding stream to be nil. It won't read/write the stream + // or close it. + StreamForbidden +) + +// Validate checks that `req` has a valid value and returns an error +// otherwise. +func (requirement StreamRequirement) Validate() error { + switch requirement { + case StreamAcceptAny, StreamPreferFile, StreamForbidden: + return nil + default: + return fmt.Errorf("invalid stream requirement %d", requirement) + } +} diff --git a/pipe/streams.go b/pipe/streams.go new file mode 100644 index 0000000..8ddd578 --- /dev/null +++ b/pipe/streams.go @@ -0,0 +1,142 @@ +package pipe + +import ( + "io" + "sync" +) + +// InputStream represents `stdin` for a stage, which might or might +// not need to be closed when the stage is done with it. It usually +// holds an `io.Reader`, which can be retrieved using `Reader()`. Its +// `Close()` method closes the reader if necessary (i.e., if the +// `InputStream` was constructed using `ClosingInput()`. The +// `Close()` method is idempotent. +// +// A nil `*InputStream` is a valid value. Its `Reader()` method +// returns `nil` and `Close()` does nothing successfully. +// +// It might seem like `InputStream` should implement `io.Reader` +// itself. But we want to avoid hiding the dynamic type of the +// `io.Reader` that is being used as the stdin of a pipeline. That +// object might be of a type that is subject to optimizations that +// aren't available for a generic `io.Reader`. For example, it might +// be an `*os.File` (which can be passed directly to subcommands or to +// `splice(2)`), or it might implement `io.WriterTo`. +type InputStream struct { + reader io.Reader + + // once is used to ensure that `Close()` is only called once. + once sync.Once + + // closer is set to `nil` after the first call to `Close()`. + closer io.Closer + + // closeErr is set to the error returned by the first call to + // `Close()`, and returned from that and any subsequent calls to + // `Close()`. + closeErr error +} + +// The stage may read from r but must not close it. +func Input(r io.Reader) *InputStream { + return &InputStream{reader: r} +} + +// The stage is responsible for closing r. +func ClosingInput(r io.ReadCloser) *InputStream { + return &InputStream{reader: r, closer: r} +} + +func (s *InputStream) Reader() io.Reader { + if s == nil { + return nil + } + return s.reader +} + +// Close closes the underlying reader if necessary. If `s` was +// constructed using `ClosingInput()`, then close the `io.ReadCloser` +// that was passed to that function. If `s` is `nil` or was +// constructed using `Input()`, then do nothing successfully. +func (s *InputStream) Close() error { + if s == nil { + return nil + } + + s.once.Do(func() { + if s.closer != nil { + s.closeErr = s.closer.Close() + s.closer = nil + } + }) + + return s.closeErr +} + +// OutputStream represents `stdout` for a stage, which might or might +// not need to be closed when the stage is done with it. It usually +// holds an `io.Writer`, which can be retrieved using `Writer()`. Its +// `Close()` method closes the writer if necessary (i.e., if the +// `OutputStream` was constructed using `ClosingOutput()`. The +// `Close()` method is idempotent. +// +// A nil `*OutputStream` is a valid value. Its `Writer()` method +// returns `nil` and `Close()` does nothing successfully. +// +// It might seem like `OutputStream` should implement `io.Writer` +// itself. But we want to avoid hiding the dynamic type of the +// `io.Writer` that is being used as the stdout of a pipeline. That +// object might be of a type that is subject to optimizations that +// aren't available for a generic `io.Writer`. For example, it might +// be an `*os.File` (which can be passed directly to subcommands or to +// `splice(2)`), or it might implement `io.ReaderFrom`. +type OutputStream struct { + writer io.Writer + + // once is used to ensure that `Close()` is only called once. + once sync.Once + + // closer is set to `nil` after the first call to `Close()`. + closer io.Closer + + // closeErr is set to the error returned by the first call to + // `Close()`, and returned from that and any subsequent calls to + // `Close()`. + closeErr error +} + +// The stage may write to w but must not close it. +func Output(w io.Writer) *OutputStream { + return &OutputStream{writer: w} +} + +// The stage is responsible for closing w. +func ClosingOutput(w io.WriteCloser) *OutputStream { + return &OutputStream{writer: w, closer: w} +} + +func (s *OutputStream) Writer() io.Writer { + if s == nil { + return nil + } + return s.writer +} + +// Close closes the underlying writer if necessary. If `s` was +// constructed using `ClosingOutput()`, then close the +// `io.WriteCloser` that was passed to that function. If `s` is `nil` +// or was constructed using `Output()`, then do nothing successfully. +func (s *OutputStream) Close() error { + if s == nil { + return nil + } + + s.once.Do(func() { + if s.closer != nil { + s.closeErr = s.closer.Close() + s.closer = nil + } + }) + + return s.closeErr +}