Testing Guide

At Optakt, we consistently write tests to ensure a reliable engineering environment where quality is paramount. Over the course of the product development life cycle, testing saves time and money, and helps developers write better code, more efficiently. Untested code is fragile, difficult to maintain and becomes questionable as soon as changes are made.

This guide assumes that you are already familiar with Go testing.

External projects

This guide assumes that you are dealing with an Optakt project. Client projects might have different testing practices. In such a case, first make sure to keep your tests consistent with theirs, and if you have any doubts, please voice them on Slack. If the tests of the original project are really problematic, we can discuss refactoring them and changing the way their tests are written in the future to improve code quality and reduce technical debt.

Unit Tests

This section outlines a few of the rules we try to follow when it comes to Go unit tests.

Naming Conventions

Unit tests should have consistent names. The best way to go about it is to follow the official guidelines of the Go testing package, which state that:

The naming convention to declare tests for the package, a function F, a type T and method M on type T are:

  func Test() { ... }
  func TestF() { ... }
  func TestT() { ... }
  func TestT_M() { ... }

You might sometimes want to have multiple test functions for a single method, in which case the tests should be prefixed as defined above, and followed with the purpose of the test:

func TestMyType_MethodDoesSomethingWhenX(t *testing.T) { ... }
func TestMyType_MethodDoesHandleXFailure(t *testing.T) { ... }
// And so on.

When it comes to subtests, the names of individual subtests should be lowercased and concise. The tests usually start with a subtest called nominal case which verifies that the tested component behaves as expected in a baseline situation, where no failures occur and no edge cases are handled. Subsequent tests should follow the paths through which the function can flow from top to bottom.

With the following example:

    // server.go
    func (s *Server) GetHeightForBlock(_ context.Context, req *GetHeightForBlockRequest) (*GetHeightForBlockResponse, error) {

        err := s.validate.Struct(req)
        if err != nil {
            return nil, fmt.Errorf("bad request: %w", err)
        }

        blockID := flow.HashToID(req.BlockID)
        height, err := s.index.HeightForBlock(blockID)
        if err != nil {
            return nil, fmt.Errorf("could not get height for block: %w", err)
        }

        res := GetHeightForBlockResponse{
            BlockID: req.BlockID,
            Height:  height,
        }

        return &res, nil
    }

There are three possible paths through which this function can be traversed:

  • The nominal case, where s.validate.Struct and s.index.HeightForBlock return no errors and the function returns a valid response.
  • The case where s.validate.Struct returns an error.
  • The case where s.index.HeightForBlock returns an error.

And here is what the tests should look like:

    // server_internal_test.go
    func TestServer_GetHeightForBlock(t *testing.T) {
        blockID := mocks.GenericHeader.ID()
        tests := []struct {
            name string
            req *GetHeightForBlockRequest
            mockErr error
            wantBlockID flow.Identifier
            checkErr require.ErrorAssertionFunc
        }{
            {
                name: "nominal case",
                req: &GetHeightForBlockRequest{
                    BlockID: mocks.ByteSlice(blockID),
                },
                mockErr: nil,
                wantBlockID: blockID,
                checkErr: require.NoError,
            },
            {
                name: "handles invalid block request",
                req: &GetHeightForBlockRequest{},
                checkErr: require.Error,
            },
            {
                name: "handles failure to retrieve block height from index",
                req: &GetHeightForBlockRequest{
                    BlockID: mocks.ByteSlice(blockID),
                },
                mockErr: mocks.GenericError,
                checkErr: require.Error,
            },
        }

        for _, test := range tests {
            test := test
            t.Run(test.name, func(t *testing.T) {
                t.Parallel()

                index := mocks.BaselineReader(t)
                index.HeightForBlockFunc = func(blockID flow.Identifier) (uint64, error) {
                    return mocks.GenericHeight, test.mockErr
                }

                s := Server{
                    index:    index,
                    validate: validator.New(),
                }

                got, err := s.GetHeightForBlock(context.Background(), test.req)

                test.checkErr(t, err)
                if err == nil {
                    assert.Equal(t, mocks.GenericHeight, got.Height)
                    assert.Equal(t, test.wantBlockID[:], got.BlockID)
                }
            })
        }
    }

Internal Unit Tests

In most cases, packages can be tested using external tests only. When writing tests for a package called xyz, the external tests should be in the same folder, but in a package called xyz_test. This case is handled by Go natively and will therefore not result in complaints about there being two different packages within the same directory. Of course, there are exceptions. If you need to test some internal logic, those tests must be declared in package xyz itself, so that they can access unexported identifiers, and live in a file suffixed with _internal_test.go.

Mocks

When it comes to mocking dependencies for tests, we prefer to use simple hand-made mocks rather than to use testing frameworks to generate them. The Go language makes it easy to do so elegantly by creating structures that implement the interfaces for dependencies of the tested code and exposing functions that match the interface's signature as attributes that can be overridden externally.

    package mocks

    import (
        "testing"

        "github.com/onflow/flow-go/ledger/complete/mtrie/trie"
    )

    type Loader struct {
        TrieFunc func() (*trie.MTrie, error)
    }

    func BaselineLoader(t *testing.T) *Loader {
        t.Helper()

        l := Loader{
            TrieFunc: func() (*trie.MTrie, error) {
                return GenericTrie, nil
            },
        }

        return &l
    }

    func (l *Loader) Trie() (*trie.MTrie, error) {
        return l.TrieFunc()
    }

Using those mocks is as simple as instantiating a baseline version of the mock and setting its attributes to the desired functions:

    // ...
    t.Run("handles failure to load checkpoint", func(t *testing.T) {
        t.Parallel()

        tr, st := baselineFSM(t, StatusEmpty)

        load := mocks.BaselineLoader(t)
        load.TrieFunc = func() (*trie.MTrie, error) {
            return nil, mocks.GenericError
        }

        tr.load = load

        err := tr.BootstrapState(st)
        assert.Error(t, err)
    })
    // ...
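
The pattern above can also be sketched without any external dependency. The following is a minimal, self-contained illustration and not code from an Optakt project: the Loader interface, the LoaderMock type, the Bootstrap function and the errGeneric value are all hypothetical names introduced for this sketch. The t *testing.T parameter and the t.Helper() call of the baseline constructor are left out so that the sketch runs as a standalone program.

```go
package main

import (
	"errors"
	"fmt"
)

// Loader is a hypothetical interface that the code under test depends on.
type Loader interface {
	Load() (string, error)
}

// LoaderMock implements Loader by delegating to a function attribute that
// tests can override.
type LoaderMock struct {
	LoadFunc func() (string, error)
}

// BaselineLoader returns a mock whose method succeeds with a generic value.
func BaselineLoader() *LoaderMock {
	l := LoaderMock{
		LoadFunc: func() (string, error) {
			return "generic", nil
		},
	}

	return &l
}

// Load satisfies the Loader interface by calling the overridable function.
func (l *LoaderMock) Load() (string, error) {
	return l.LoadFunc()
}

// Bootstrap is a hypothetical function under test. It wraps loader failures.
func Bootstrap(load Loader) (string, error) {
	value, err := load.Load()
	if err != nil {
		return "", fmt.Errorf("could not load: %w", err)
	}

	return value, nil
}

// errGeneric plays the role of a shared generic error value.
var errGeneric = errors.New("dummy error")

func main() {
	// Nominal case: the baseline mock succeeds.
	load := BaselineLoader()
	value, err := Bootstrap(load)
	fmt.Println(value, err) // prints "generic <nil>"

	// Failure case: override a single attribute of the baseline mock.
	load.LoadFunc = func() (string, error) {
		return "", errGeneric
	}
	_, err = Bootstrap(load)
	fmt.Println(errors.Is(err, errGeneric)) // prints "true"
}
```

Because the mock only stores functions, each test overrides exactly the behavior it cares about and inherits sensible defaults for everything else.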

When the functions you mock receive arguments that are fully determined by the test inputs, the nominal case tests should also make assertions on those arguments. For error tests, simply ignore those values in the mock.

    // ...
    t.Run("nominal case", func(t *testing.T) {
        t.Parallel()

        mock := mocks.BaselineTransport(t)
        mock.Request = func(method, url string) (*http.Response, error) {
            assert.Equal(t, http.MethodPost, method)
            assert.Equal(t, "https://google.com", url)

            return testResponse, nil
        }

        subject := NewServer(mock)

        got, err := subject.Call()
        require.NoError(t, err)
        assert.NotNil(t, got)
        // Assertions...
    })
    t.Run("handles failure to make request", func(t *testing.T) {
        t.Parallel()

        mock := mocks.BaselineTransport(t)
        mock.Request = func(string, string) (*http.Response, error) {
            return nil, mocks.GenericError
        }

        subject := NewServer(mock)

        _, err := subject.Call()
        assert.Error(t, err)
    })
    // ...

Mocks Package

When the same mock needs to be used within multiple packages, it can make sense to create a package to store common mocks. This package should be located at ./testing/mocks from the root of the repository, and each file within it should define a single mock.

Pseudorandom Generic Values

When using test data for unit tests, it is often a good idea to use randomly generated data as the inputs. An unconstrained data set avoids the bias where a test passes only because it is given a hand-picked set of valid inputs, while other inputs would have exposed a flaw in the logic.

In order for the tests to be repeatable and for results to be consistent though, the given inputs should not be completely random, but instead they should be pseudorandom, with the same initial seed, to ensure the same sequence of "random" tests.

Here is an example of such a value being generated.

    func GenericAddresses(number int) []flow.Address {
        // Ensure consistent deterministic results.
        random := rand.New(rand.NewSource(5))

        var addresses []flow.Address
        for i := 0; i < number; i++ {
            var address flow.Address
            binary.BigEndian.PutUint64(address[0:], random.Uint64())

            addresses = append(addresses, address)
        }

        return addresses
    }

    func GenericAddress(index int) flow.Address {
        return GenericAddresses(index + 1)[index]
    }
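
To make the effect of the fixed seed concrete, here is a minimal sketch that relies only on the standard library. GenericValues is a hypothetical helper introduced for this illustration, and the seed value is arbitrary. Two separate calls produce the same sequence, which is what keeps test runs repeatable.

```go
package main

import (
	"fmt"
	"math/rand"
)

// GenericValues is a hypothetical helper that returns n pseudorandom values.
// A new generator is created with the same seed on every call, so the
// sequence is identical each time.
func GenericValues(n int) []uint64 {
	// Ensure consistent deterministic results.
	random := rand.New(rand.NewSource(5))

	values := make([]uint64, 0, n)
	for i := 0; i < n; i++ {
		values = append(values, random.Uint64())
	}

	return values
}

func main() {
	first := GenericValues(3)
	second := GenericValues(3)

	// Both calls yield the same values, in the same order.
	for i := range first {
		fmt.Println(first[i] == second[i]) // prints "true" three times
	}
}
```

A side effect of this design is that a shorter request is always a prefix of a longer one, so adding more generic values to a test does not change the ones it already used.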

Warning

While randomly generating valid inputs makes sense, randomly generating invalid inputs does not. In the case of invalid inputs, it is much better to have an exhaustive list of all types of cases that are expected to be invalid and always test each one of them.
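
As an illustration of such an exhaustive list, the sketch below enumerates each category of invalid input for a made-up validator. ValidateBlockID and its rules (exactly 32 bytes, not all zeroes) are assumptions chosen for this example only and do not come from a real project.

```go
package main

import (
	"errors"
	"fmt"
)

// ValidateBlockID is a hypothetical validator: it accepts exactly 32 bytes
// that are not all zero.
func ValidateBlockID(id []byte) error {
	if len(id) != 32 {
		return fmt.Errorf("invalid block ID length: %d", len(id))
	}
	for _, b := range id {
		if b != 0 {
			return nil
		}
	}
	return errors.New("block ID is all zeroes")
}

// invalidBlockIDs lists every category of input expected to be rejected.
// Each category is written out explicitly rather than generated randomly.
var invalidBlockIDs = []struct {
	name string
	id   []byte
}{
	{name: "nil slice", id: nil},
	{name: "empty slice", id: []byte{}},
	{name: "too short", id: make([]byte, 31)},
	{name: "too long", id: make([]byte, 33)},
	{name: "all zeroes", id: make([]byte, 32)},
}

func main() {
	for _, test := range invalidBlockIDs {
		err := ValidateBlockID(test.id)
		fmt.Printf("%s: rejected=%t\n", test.name, err != nil)
	}
}
```

In an actual test, each entry of the list would become one subtest that asserts an error is returned.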

Generic Values Package

When the same generic values need to be used within multiple packages, it can make sense to create a package to store common generic values. This package should be located at ./testing/mocks from the root of the repository, and values should be in a file called generic.go.

Parallelization

Since Go 1.7, which introduced subtests, individual subtests can be run in parallel. This can be done by calling t.Parallel in each subtest. Calling this function signals that the test is to be run in parallel with (and only with) other parallel tests. The number of tests running in parallel is capped by the -parallel flag of go test, which defaults to the value of GOMAXPROCS.

There are multiple advantages to parallelizing tests:

  • It ensures that regardless of the order in which inputs are given, components behave as expected.
  • It maximizes performance, which in turn results in a faster CI and a faster workflow for everyone, which allows us to write more tests and therefore produce more actionable data to find bugs as well as improve test cases and coverage.
  • It makes it possible to ensure that the components you test are concurrency-safe.

Parallelizing table-driven tests

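When it comes to table-driven tests, a common pitfall is to call t.Parallel in a subtest without capturing the loop variable that holds the test case. Before Go 1.22, the loop variable was shared across iterations, and since parallel subtests only resume once the loop has finished, they would all observe the last test case. From Go 1.22 onwards, each iteration has its own variable, but capturing it explicitly remains harmless and keeps the code correct on older toolchains.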

Here is an example of how it should be done:

func TestGroupedParallel(t *testing.T) {
    for _, tc := range tests {
        tc := tc // capture range variable
        t.Run(tc.Name, func(t *testing.T) {
            t.Parallel()
            // ...
        })
    }
}

Standard testing Package

The standard testing package is very powerful, and does not require additional frameworks to be used efficiently. The only exceptions we make are the stretchr/testify/assert and stretchr/testify/require packages, which we use only for convenience, as they expose assertion functions that produce consistent outputs and make tests easy to understand.

Subtests and Sub-benchmarks

The testing package exposes a Run method on the T type which makes it possible to nest tests within tests. This can be very useful, as it enables creating a hierarchical structure within a test.

// index_internal_test.go
func TestIndex(t *testing.T) {
    // ...
    t.Run("collections", func(t *testing.T) {
        t.Parallel()

        collections := mocks.GenericCollections(4)

        reader, writer, db := setupIndex(t)
        defer db.Close()

        assert.NoError(t, writer.Collections(mocks.GenericHeight, collections))
        // Close the writer to make it commit its transactions.
        require.NoError(t, writer.Close())

        // NOTE: The following subtests should NOT be run in parallel, because of the deferral
        // to close the database above.
        t.Run("retrieve collection by ID", func(t *testing.T) {
            got, err := reader.Collection(collections[0].ID())

            require.NoError(t, err)
            assert.Equal(t, collections[0], got)
        })

        t.Run("retrieve collections by height", func(t *testing.T) {
            got, err := reader.CollectionsByHeight(mocks.GenericHeight)

            require.NoError(t, err)
            assert.ElementsMatch(t, mocks.GenericCollectionIDs(4), got)
        })

        t.Run("retrieve transactions from collection", func(t *testing.T) {
            // For now this index is not used.
        })
    })
    // ...
}

Structure

Those tests are usually structured as follows:

  • At the top-level, the common testing variables that are used in multiple subtests.
    • This can be a logger that discards logs, some test input that can be reused, expected return values shared between tests, and so on.
  • In each sub-test:
    • The test is defined as parallelized;
    • The inputs and outputs are defined in variables;
    • The mocks are defined and their methods overridden when necessary;
    • If the test is for a method, the method's struct (the subject of the test) is created and given the mocks;
    • The tested method is called on the subject, its return values are often stored in got and err: got, err := mything.Do(input);
    • If the test should not error, require.NoError is called on the error return;
    • Assertions are made on the returned value.

Table-Driven Tests

It makes a lot of sense to use subtests when testing the behavior of complex components, but it is better to use table-driven tests when testing a simple function with an expected output for a given input, and there are many cases to cover. For such cases, table-driven tests massively improve clarity and readability.

They should not be used blindly for all tests, however. In cases where the tested component is complex and testing its methods cannot be reduced to a common setup, a call to a function and an assertion on the output, trying to use table-driven tests at all costs might lead to messy code, where the subtest which runs the test case is full of conditions to try to handle each separate setup. This is usually a sign that using simple tests and subtests would be a better approach.
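
For a simple function, a table-driven test might look like the following minimal sketch. Abs is a hypothetical function introduced only for this illustration, and plain comparisons from the standard library stand in for the testify assertions so that the sketch has no external dependency. The code is meant to live in a file ending in _test.go and to be run with go test.

```go
package main

import "testing"

// Abs is a hypothetical function under test. It returns the absolute value
// of x.
func Abs(x int) int {
	if x < 0 {
		return -x
	}
	return x
}

func TestAbs(t *testing.T) {
	// Each test case only differs by its input and its expected output.
	tests := []struct {
		name  string
		input int
		want  int
	}{
		{name: "positive value", input: 3, want: 3},
		{name: "negative value", input: -3, want: 3},
		{name: "zero", input: 0, want: 0},
	}

	for _, test := range tests {
		test := test // capture range variable
		t.Run(test.name, func(t *testing.T) {
			t.Parallel()

			got := Abs(test.input)
			if got != test.want {
				t.Errorf("Abs(%d) = %d, want %d", test.input, got, test.want)
			}
		})
	}
}
```

The setup, the call and the assertion are identical for every case, which is the situation in which a table keeps the test short and makes it cheap to add new cases.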

Case Study: The Flow DPS Mapper

Sometimes, a core piece of software might seem impossible to test. That was the case for the mapper component in Flow DPS at some point, when its main function consisted of a 453-line loop which orchestrated the use of all the other components of the application.

// mapper_old.go
package mapper

import (
    "bytes"
    "context"
    "errors"
    "fmt"
    "os"
    "sort"
    "sync"

    "github.com/gammazero/deque"
    "github.com/rs/zerolog"

    "github.com/onflow/flow-go/ledger"
    "github.com/onflow/flow-go/ledger/complete/mtrie/flattener"
    "github.com/onflow/flow-go/ledger/complete/mtrie/node"
    "github.com/onflow/flow-go/ledger/complete/mtrie/trie"
    "github.com/onflow/flow-go/ledger/complete/wal"
    "github.com/onflow/flow-go/model/flow"

    "github.com/optakt/flow-dps/models/dps"
    "github.com/optakt/flow-dps/models/index"
)

type Mapper struct {
    log zerolog.Logger
    cfg Config

    chain Chain
    feed  Feeder
    index index.Writer

    wg   *sync.WaitGroup
    stop chan struct{}
}

// New creates a new mapper that uses chain data to map trie updates to blocks
// and then passes on the details to the indexer for indexing.
func New(log zerolog.Logger, chain Chain, feed Feeder, index index.Writer, options ...func(*Config)) (*Mapper, error) {

    // We don't use a checkpoint by default. The options can set one, in which
    // case we will add the checkpoint as a finalized state commitment in our
    // trie registry.
    cfg := Config{
        CheckpointFile: "",
        PostProcessing: PostNoop,
    }
    for _, option := range options {
        option(&cfg)
    }

    // Check if the checkpoint file exists.
    if cfg.CheckpointFile != "" {
        stat, err := os.Stat(cfg.CheckpointFile)
        if err != nil {
            return nil, fmt.Errorf("invalid checkpoint file: %w", err)
        }
        if stat.IsDir() {
            return nil, fmt.Errorf("invalid checkpoint file: directory")
        }
    }

    i := Mapper{
        log:   log,
        chain: chain,
        feed:  feed,
        index: index,
        cfg:   cfg,
        wg:    &sync.WaitGroup{},
        stop:  make(chan struct{}),
    }

    return &i, nil
}

func (m *Mapper) Stop(ctx context.Context) error {
    close(m.stop)
    done := make(chan struct{})
    go func() {
        m.wg.Wait()
        close(done)
    }()
    select {
    case <-ctx.Done():
        return ctx.Err()
    case <-done:
        return nil
    }
}

// NOTE: We might want to move height and tree (checkpoint) to parameters of the
// run function; that would make it quite easy to resume from an arbitrary
// point in the LedgerWAL and get rid of the related struct fields.

func (m *Mapper) Run() error {
    m.wg.Add(1)
    defer m.wg.Done()

    // We start trying to map at the root height.
    height, err := m.chain.Root()
    if err != nil {
        return fmt.Errorf("could not get root height: %w", err)
    }

    // We always initialize an empty state trie to refer to the first step
    // before the checkpoint. If there is no checkpoint, then the step after the
    // checkpoint will also just be the empty trie. Otherwise, the second trie
    // will load the checkpoint trie.
    empty := trie.NewEmptyMTrie()
    var tree *trie.MTrie
    if m.cfg.CheckpointFile == "" {
        tree = empty
    } else {
        m.log.Info().Msg("checkpoint rebuild started")
        file, err := os.Open(m.cfg.CheckpointFile)
        if err != nil {
            return fmt.Errorf("could not open checkpoint file: %w", err)
        }
        checkpoint, err := wal.ReadCheckpoint(file)
        if err != nil {
            return fmt.Errorf("could not read checkpoint: %w", err)
        }
        trees, err := flattener.RebuildTries(checkpoint)
        if err != nil {
            return fmt.Errorf("could not rebuild tries: %w", err)
        }
        if len(trees) != 1 {
            return fmt.Errorf("should only have one trie in root checkpoint (tries: %d)", len(trees))
        }
        tree = trees[0]
        m.log.Info().Msg("checkpoint rebuild finished")
    }

    m.log.Info().Msg("path collection started")

    // We have to index all of the paths from the checkpoint; otherwise, we will
    // miss every single one of the bootstrapped registers.
    paths := make([]ledger.Path, 0, len(tree.AllPayloads()))
    queue := deque.New()
    root := tree.RootNode()
    if root != nil {
        queue.PushBack(root)
    }
    for queue.Len() > 0 {
        node := queue.PopBack().(*node.Node)
        if node.IsLeaf() {
            path := node.Path()
            paths = append(paths, *path)
            continue
        }
        if node.LeftChild() != nil {
            queue.PushBack(node.LeftChild())
        }
        if node.RightChild() != nil {
            queue.PushBack(node.RightChild())
        }
    }

    m.log.Info().Int("paths", len(paths)).Msg("path collection finished")

    m.log.Info().Msg("path sorting started")

    sort.Slice(paths, func(i int, j int) bool {
        return bytes.Compare(paths[i][:], paths[j][:]) < 0
    })

    m.log.Info().Msg("path sorting finished")

    // When trying to go from one finalized block to the next, we keep a list
    // of intermediary tries until the full set of transitions have been
    // identified. We keep track of these transitions as steps in this map.
    steps := make(map[flow.StateCommitment]*Step)

    // We start at an "imaginary" step that refers to an empty trie, has no
    // paths and no previous commit. We consider this step already done, so it
    // will never be indexed; it's merely used as the sentinel value for
    // stopping when we index the first block. It also makes sure that we don't
    // return a `nil` trie if we abort indexing before the first block is done.
    emptyCommit := flow.DummyStateCommitment
    steps[emptyCommit] = &Step{
        Commit: flow.StateCommitment{},
        Paths:  nil,
        Tree:   empty,
    }

    // We then add a second step that refers to the first step that is already
    // done, which uses the commit of the initial state trie after the
    // checkpoint has been loaded, and contains all of the paths found in the
    // initial checkpoint state trie. This will make sure that we index all the
    // data from the checkpoint as part of the first block.
    rootCommit := flow.StateCommitment(tree.RootHash())
    steps[rootCommit] = &Step{
        Commit: emptyCommit,
        Paths:  paths,
        Tree:   tree,
    }

    // This is how we let the indexing loop know that the first "imaginary" step
    // was already indexed. The `commitPrev` value is used as a sentinel value
    // for when to stop going backwards through the steps when indexing a block.
    // This means the value is always set to the last already indexed step.
    commitPrev := emptyCommit

    m.log.Info().Msg("state indexing started")

    // Next, we launch into the loop that is responsible for mapping all
    // incoming trie updates to a block. The loop itself has no concept of what
    // the next state commitment is that we should look at. It will simply try
    // to find a previous step for _any_ trie update that comes in. This means
    // that the first trie update needs to either apply to the empty trie or to
    // the trie after the checkpoint in order to be processed.
    once := &sync.Once{}
Outer:
    for {
        // We want to check in this tight loop if we want to quit, just in case
        // we get stuck on a timed out network connection.
        select {
        case <-m.stop:
            break Outer
        default:
            // keep going
        }

        log := m.log.With().
            Uint64("height", height).
            Hex("commit_prev", commitPrev[:]).Logger()

        // As a first step, we retrieve the state commitment of the finalized
        // block at the current height; we start at the root height and then
        // increase it each time we are done indexing a block. Once an applied
        // trie update gives us a state trie with the same root hash as
        // `commitNext`, we have reached the end state of the next finalized
        // block and can index all steps in-between for that block height.
        commitNext, err := m.chain.Commit(height)

        // If the retrieval times out, it's possible that we are on a live chain
        // and the next block has not been finalized yet. We should thus simply
        // retry until we have a new block.
        if errors.Is(err, dps.ErrTimeout) {
            log.Warn().Msg("commit retrieval timed out, retrying")
            continue Outer
        }

        // If we have reached the end of the finalized blocks, we are probably
        // on a historical chain and there are no more finalized blocks for the
        // related spork. We can exit without error.
        if errors.Is(err, dps.ErrFinished) {
            log.Debug().Msg("reached end of finalized chain")
            break Outer
        }

        // Any other error should not happen and should crash explicitly.
        if err != nil {
            return fmt.Errorf("could not retrieve next commit (height: %d): %w", height, err)
        }

        log = log.With().Hex("commit_next", commitNext[:]).Logger()

    Inner:
        for {
            // We want to check in this tight loop if we want to quit, just in case
            // we get stuck on a timed out network connection.
            select {
            case <-m.stop:
                break Outer
            default:
                // keep going
            }

            // When we have the state commitment of the next finalized block, we
            // check to see if we find a trie for it in our steps. If we do, it
            // means that we have steps from the last finalized block to the
            // finalized block at the current height. This condition will
            // trigger immediately for every empty block.
            _, ok := steps[commitNext]
            if ok {
                break Inner
            }

            // If we don't find a trie for the current state commitment, we need
            // to keep applying trie updates to state tries until one of them
            // does have the correct commit. We simply feed the next trie update
            // here.
            update, err := m.feed.Update()

            // Once more, we might be on a live spork and the next delta might not
            // be available yet. In that case, keep trying.
            if errors.Is(err, dps.ErrTimeout) {
                log.Warn().Msg("delta retrieval timed out, retrying")
                continue Inner
            }

            // Similarly, if no more deltas are available, we reached the end of
            // the WAL and we are done reconstructing the execution state.
            if errors.Is(err, dps.ErrFinished) {
                log.Debug().Msg("reached end of delta log")
                break Outer
            }

            // Other errors should fail execution as they should not happen.
            if err != nil {
                return fmt.Errorf("could not retrieve next delta: %w", err)
            }

            // NOTE: We used to require a copy of the `RootHash` here, when it
            // was still a byte slice, as the underlying slice was being reused.
            // It was changed to a value type that is always copied now.
            commitBefore := flow.StateCommitment(update.RootHash)

            log := log.With().Hex("commit_before", commitBefore[:]).Logger()

            // Once we have our new update and know which trie it should be
            // applied to, we check to see if we have such a trie in our current
            // steps. If not, we can simply skip it; this can happen, for
            // example, when there is an execution fork and the trie update
            // applies to an obsolete part of the blockchain history.
            step, ok := steps[commitBefore]
            if !ok {
                log.Debug().Msg("skipping trie update without matching trie")
                continue Inner
            }

            // We de-duplicate the paths and payloads here. This replicates some
            // code that is part of the execution node and has moved between
            // different layers of the architecture. We keep it to be safe for
            // all versions of the Flow dependencies.
            // NOTE: Past versions of this code required paths to be copied,
            // because the underlying slice was being re-used. On the contrary,
            // deep-copying payloads was a bad idea, because they were already
            // being copied by the trie insertion code, and it would have led to
            // twice the memory usage.
            paths = make([]ledger.Path, 0, len(update.Paths))
            lookup := make(map[ledger.Path]*ledger.Payload)
            for i, path := range update.Paths {
                _, ok := lookup[path]
                if !ok {
                    paths = append(paths, path)
                }
                lookup[path] = update.Payloads[i]
            }
            sort.Slice(paths, func(i, j int) bool {
                return bytes.Compare(paths[i][:], paths[j][:]) < 0
            })
            payloads := make([]ledger.Payload, 0, len(paths))
            for _, path := range paths {
                payloads = append(payloads, *lookup[path])
            }

            // We can now apply the trie update to the state trie as it was at
            // the previous step. This is where the trie code will deep-copy the
            // payloads.
            // NOTE: It's important that we don't shadow the variable here,
            // otherwise the root trie will never go out of scope and we will
            // never garbage collect any of the root trie payloads that have
            // been replaced by subsequent trie updates.
            tree, err = trie.NewTrieWithUpdatedRegisters(step.Tree, paths, payloads)
            if err != nil {
                return fmt.Errorf("could not update trie: %w", err)
            }

            // We then store the new trie along with the state commitment of its
            // parent and the paths that were changed. This will make it
            // available for subsequent trie updates to be applied to it, and it
            // will also allow us to reconstruct the payloads changed in this
            // step by retrieving them directly from the trie with the given
            // paths.
            commitAfter := flow.StateCommitment(tree.RootHash())
            step = &Step{
                Commit: commitBefore,
                Paths:  paths,
                Tree:   tree,
            }
            steps[commitAfter] = step

            log.Debug().Hex("commit_after", commitAfter[:]).Msg("trie update applied")
        }

        // At this point we have identified a step that has led to the state
        // commitment of the finalized block at the current height. We can
        // retrieve some additional indexing data, such as the block header and
        // the events that resulted from transactions in the block.
        header, err := m.chain.Header(height)
        if err != nil {
            return fmt.Errorf("could not retrieve header: %w (height: %d)", err, height)
        }
        events, err := m.chain.Events(height)
        if err != nil {
            return fmt.Errorf("could not retrieve events: %w (height: %d)", err, height)
        }
        transactions, err := m.chain.Transactions(height)
        if err != nil {
            return fmt.Errorf("could not retrieve transactions: %w (height: %d)", err, height)
        }
        collections, err := m.chain.Collections(height)
        if err != nil {
            return fmt.Errorf("could not retrieve collections: %w (height: %d)", err, height)
        }
        blockID := header.ID()

        // TODO: Refactor the mapper in https://github.com/optakt/flow-dps/issues/128
        // and replace naive if statements around indexing.

        // We then index the data for the finalized block at the current height.
        if m.cfg.IndexHeaders {
            err = m.index.Header(height, header)
            if err != nil {
                return fmt.Errorf("could not index header: %w", err)
            }
        }
        if m.cfg.IndexCommit {
            err = m.index.Commit(height, commitNext)
            if err != nil {
                return fmt.Errorf("could not index commit: %w", err)
            }
        }
        if m.cfg.IndexEvents {
            err = m.index.Events(height, events)
            if err != nil {
                return fmt.Errorf("could not index events: %w", err)
            }
        }
        if m.cfg.IndexBlocks {
            err = m.index.Height(blockID, height)
            if err != nil {
                return fmt.Errorf("could not index block heights: %w", err)
            }
        }
        if m.cfg.IndexTransactions {
            err = m.index.Transactions(blockID, collections, transactions)
            if err != nil {
                return fmt.Errorf("could not index transactions: %w", err)
            }
        }

        // In order to index the payloads, we step back from the state
        // commitment of the finalized block at the current height to the state
        // commitment of the last finalized block that was indexed. For each
        // step, we collect all the payloads by using the paths for the step and
        // index them as we go.
        // NOTE: We keep track of the paths for which we already indexed
        // payloads, so we can skip them in earlier steps. One inherent benefit
        // of stepping from the last step to the first step is that this will
        // automatically use only the latest update of a register, which is
        // exactly what we want.
        commit := commitNext
        updated := make(map[ledger.Path]struct{})
        for commit != commitPrev {

            // In the first part, we get the step we are currently at and filter
            // out any paths that have already been updated.
            step := steps[commit]
            paths := make([]ledger.Path, 0, len(step.Paths))
            for _, path := range step.Paths {
                _, ok := updated[path]
                if ok {
                    continue
                }
                paths = append(paths, path)
                updated[path] = struct{}{}
            }

            if !m.cfg.IndexPayloads {
                commit = step.Commit
                continue
            }

            // We then divide the remaining paths into chunks of 1000. For each
            // batch, we retrieve the payloads from the state trie as it was at
            // the end of this block and index them.
            count := 0
            n := 1000
            total := (len(paths) + n - 1) / n
            log.Debug().Int("num_paths", len(paths)).Int("num_batches", total).Msg("path batching executed")
            for start := 0; start < len(paths); start += n {
                // This loop may take a while, especially for the root checkpoint
                // updates, so check if we should quit.
                select {
                case <-m.stop:
                    break Outer
                default:
                    // keep going
                }

                end := start + n
                if end > len(paths) {
                    end = len(paths)
                }
                batch := paths[start:end]
                payloads := step.Tree.UnsafeRead(batch)
                err = m.index.Payloads(height, batch, payloads)
                if err != nil {
                    return fmt.Errorf("could not index payloads: %w", err)
                }

                count++

                log.Debug().Int("batch", count).Int("start", start).Int("end", end).Msg("path batch indexed")
            }

            // Finally, we forward the commit to the previous trie update and
            // repeat until we have stepped all the way back to the last indexed
            // commit.
            commit = step.Commit
        }

        // At this point, we can delete any trie that does not correspond to
        // the state that we have just reached. This will allow the garbage
        // collector to free up any payload that has been changed and which is
        // no longer part of the state trie at the newly indexed finalized
        // block.
        for key := range steps {
            if key != commitNext {
                delete(steps, key)
            }
        }

        // Last but not least, we take care of properly indexing the height of
        // the first indexed block and the height of the last indexed block.
        once.Do(func() { err = m.index.First(height) })
        if err != nil {
            return fmt.Errorf("could not index first height: %w", err)
        }
        err = m.index.Last(height)
        if err != nil {
            return fmt.Errorf("could not index last height: %w", err)
        }

        // We have now successfully indexed all state trie changes and other
        // data at the current height. We set the last indexed step to the last
        // step from our current height, and then increase the height to start
        // the indexing of the next block.
        commitPrev = commitNext
        height++

        log.Info().
            Hex("block", blockID[:]).
            Int("num_changes", len(updated)).
            Int("num_events", len(events)).
            Msg("block data indexed")
    }

    m.log.Info().Msg("state indexing finished")

    step := steps[commitPrev]
    m.cfg.PostProcessing(step.Tree)

    return nil
}

As written, this code was effectively untestable. Covering every possible path through such a large block of logic would have required huge, complex, unreadable tests that would break whenever any part of the logic changed, making them very costly to maintain.

To solve this problem, we refactored the original mapper into a finite-state machine that performs the same computation by applying transitions to a state.
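Before looking at the real implementation, the idea can be illustrated with a minimal sketch. The statuses, the `count` field and the `Run` loop below are hypothetical and exist only for illustration. Only the `TransitionFunc` signature and the status guard at the top of each transition mirror the actual mapper.

```go
package main

import (
    "errors"
    "fmt"
)

// Status identifies which transition should be applied next.
type Status int

// Hypothetical statuses for this sketch; the real mapper defines its own.
const (
    StatusInitialize Status = iota
    StatusWork
    StatusDone
)

// State holds everything that transitions read and mutate.
type State struct {
    status Status
    count  int
}

// TransitionFunc is a function that is applied onto the state.
type TransitionFunc func(*State) error

// Initialize moves the state from StatusInitialize to StatusWork.
func Initialize(s *State) error {
    if s.status != StatusInitialize {
        return fmt.Errorf("invalid status for initializing (%d)", s.status)
    }
    s.status = StatusWork
    return nil
}

// Work performs one unit of work per call and finishes after three units.
func Work(s *State) error {
    if s.status != StatusWork {
        return fmt.Errorf("invalid status for working (%d)", s.status)
    }
    s.count++
    if s.count >= 3 {
        s.status = StatusDone
    }
    return nil
}

// Run applies the transition registered for the current status until the
// state reaches StatusDone or a transition fails.
func Run(s *State, transitions map[Status]TransitionFunc) error {
    for s.status != StatusDone {
        transition, ok := transitions[s.status]
        if !ok {
            return errors.New("no transition registered for status")
        }
        err := transition(s)
        if err != nil {
            return fmt.Errorf("could not apply transition: %w", err)
        }
    }
    return nil
}

func main() {
    s := &State{status: StatusInitialize}
    err := Run(s, map[Status]TransitionFunc{
        StatusInitialize: Initialize,
        StatusWork:       Work,
    })
    fmt.Println(s.status == StatusDone, s.count, err)
}
```

Because each transition only depends on the state it receives, it can be tested on its own by preparing a state with a given status, calling the function once, and checking the outcome.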

mapper_new.go
package mapper

import (
    "errors"
    "fmt"
    "sync"
    "time"

    "github.com/rs/zerolog"

    "github.com/onflow/flow-go/ledger"
    "github.com/onflow/flow-go/ledger/complete/mtrie/trie"
    "github.com/onflow/flow-go/model/flow"

    "github.com/optakt/flow-dps/models/dps"
)

// TransitionFunc is a function that is applied onto the state machine's
// state.
type TransitionFunc func(*State) error

// Transitions is what applies transitions to the state of an FSM.
type Transitions struct {
    cfg   Config
    log   zerolog.Logger
    load  Loader
    chain dps.Chain
    feed  Feeder
    read  dps.Reader
    write dps.Writer
    once  *sync.Once
}

// NewTransitions returns a Transitions component using the given dependencies and options.
func NewTransitions(log zerolog.Logger, load Loader, chain dps.Chain, feed Feeder, read dps.Reader, write dps.Writer, options ...Option) *Transitions {

    cfg := DefaultConfig
    for _, option := range options {
        option(&cfg)
    }

    t := Transitions{
        log:   log.With().Str("component", "mapper_transitions").Logger(),
        cfg:   cfg,
        load:  load,
        chain: chain,
        feed:  feed,
        read:  read,
        write: write,
        once:  &sync.Once{},
    }

    return &t
}

// InitializeMapper initializes the mapper by either going into bootstrapping or
// into resuming, depending on the configuration.
func (t *Transitions) InitializeMapper(s *State) error {
    if s.status != StatusInitialize {
        return fmt.Errorf("invalid status for initializing mapper (%s)", s.status)
    }

    if t.cfg.BootstrapState {
        s.status = StatusBootstrap
        return nil
    }

    s.status = StatusResume
    return nil
}

// BootstrapState bootstraps the state by loading the checkpoint if there is one
// and initializing the elements subsequently used by the FSM.
func (t *Transitions) BootstrapState(s *State) error {
    if s.status != StatusBootstrap {
        return fmt.Errorf("invalid status for bootstrapping state (%s)", s.status)
    }

    // We always need at least one step in our forest, which is used as the
    // stopping point when indexing the payloads since the last finalized
    // block. We thus introduce an empty tree, with no paths and an
    // irrelevant previous commit.
    empty := trie.NewEmptyMTrie()
    s.forest.Save(empty, nil, flow.DummyStateCommitment)

    // The chain indexing will forward last to next and next to current height,
    // which will be the one for the checkpoint.
    first := flow.StateCommitment(empty.RootHash())
    s.last = flow.DummyStateCommitment
    s.next = first

    t.log.Info().Hex("commit", first[:]).Msg("added empty tree to forest")

    // Then, we can load the root height and apply it to the state. That
    // will allow us to load the root blockchain data in the next step.
    height, err := t.chain.Root()
    if err != nil {
        return fmt.Errorf("could not get root height: %w", err)
    }
    s.height = height

    // When bootstrapping, the loader injected into the mapper loads the root
    // checkpoint.
    tree, err := t.load.Trie()
    if err != nil {
        return fmt.Errorf("could not load root trie: %w", err)
    }
    paths := allPaths(tree)
    s.forest.Save(tree, paths, first)

    second := tree.RootHash()
    t.log.Info().Uint64("height", s.height).Hex("commit", second[:]).Int("registers", len(paths)).Msg("added checkpoint tree to forest")

    // We have successfully bootstrapped. However, no chain data for the root
    // block has been indexed yet. This is why we "pretend" that we just
    // forwarded the state to this height, so we go straight to the chain data
    // indexing.
    s.status = StatusIndex
    return nil
}

// ResumeIndexing resumes indexing the data from a previous run.
func (t *Transitions) ResumeIndexing(s *State) error {
    if s.status != StatusResume {
        return fmt.Errorf("invalid status for resuming indexing (%s)", s.status)
    }

    // When resuming, we want to avoid overwriting the `first` height in the
    // index with the height we are resuming from. Theoretically, all that would
    // be needed would be to execute a no-op on `once`, which would subsequently
    // be skipped in the height forwarding code. However, this bug was already
    // released, so we have databases where `first` was incorrectly set to the
    // height we resume from. In order to fix them, we explicitly write the
    // correct `first` height here again, while at the same time using `once` to
    // disable any subsequent attempts to write it.
    first, err := t.chain.Root()
    if err != nil {
        return fmt.Errorf("could not get root height: %w", err)
    }
    t.once.Do(func() { err = t.write.First(first) })
    if err != nil {
        return fmt.Errorf("could not write first: %w", err)
    }

    // We need to know what the last indexed height was at the point we stopped
    // indexing.
    last, err := t.read.Last()
    if err != nil {
        return fmt.Errorf("could not get last height: %w", err)
    }

    // When resuming, the loader injected into the mapper rebuilds the trie from
    // the paths and payloads stored in the index database.
    tree, err := t.load.Trie()
    if err != nil {
        return fmt.Errorf("could not restore index trie: %w", err)
    }

    // After loading the trie, we should do a sanity check on its hash against
    // the commit we indexed for it.
    hash := flow.StateCommitment(tree.RootHash())
    commit, err := t.read.Commit(last)
    if err != nil {
        return fmt.Errorf("could not get last commit: %w", err)
    }
    if hash != commit {
        return fmt.Errorf("restored trie hash does not match last commit (hash: %x, commit: %x)", hash, commit)
    }

    // At this point, we can store the restored trie in our forest, as the trie
    // for the last finalized block. We do not need to care about the parent
    // state commitment or the paths, as they should not be used.
    s.last = flow.DummyStateCommitment
    s.next = commit
    s.forest.Save(tree, nil, flow.DummyStateCommitment)

    // Lastly, we just need to point to the next height. The chain indexing will
    // then proceed with the first non-indexed block and forward the state
    // commitments accordingly.
    s.height = last + 1

    // At this point, we should be able to start indexing the chain data for
    // the next height.
    s.status = StatusIndex
    return nil
}

// IndexChain indexes chain data for the current height.
func (t *Transitions) IndexChain(s *State) error {
    if s.status != StatusIndex {
        return fmt.Errorf("invalid status for indexing chain (%s)", s.status)
    }

    log := t.log.With().Uint64("height", s.height).Logger()

    // We try to retrieve the next header until it becomes available, which
    // means all data coming from the protocol state is available after this
    // point.
    header, err := t.chain.Header(s.height)
    if errors.Is(err, dps.ErrUnavailable) {
        log.Debug().Msg("waiting for next header")
        time.Sleep(t.cfg.WaitInterval)
        return nil
    }
    if err != nil {
        return fmt.Errorf("could not get header: %w", err)
    }

    // At this point, we can retrieve the data from the consensus state. This is
    // a slight optimization for the live indexer, as it allows us to process
    // some data before the full execution data becomes available.
    guarantees, err := t.chain.Guarantees(s.height)
    if err != nil {
        return fmt.Errorf("could not get guarantees: %w", err)
    }
    seals, err := t.chain.Seals(s.height)
    if err != nil {
        return fmt.Errorf("could not get seals: %w", err)
    }

    // We can also proceed to already indexing the data related to the consensus
    // state, before dealing with anything related to execution data, which
    // might go into the wait state.
    blockID := header.ID()
    err = t.write.Height(blockID, s.height)
    if err != nil {
        return fmt.Errorf("could not index height: %w", err)
    }
    err = t.write.Header(s.height, header)
    if err != nil {
        return fmt.Errorf("could not index header: %w", err)
    }
    err = t.write.Guarantees(s.height, guarantees)
    if err != nil {
        return fmt.Errorf("could not index guarantees: %w", err)
    }
    err = t.write.Seals(s.height, seals)
    if err != nil {
        return fmt.Errorf("could not index seals: %w", err)
    }

    // Next, we try to retrieve the next commit until it becomes available,
    // at which point all the data coming from the execution data should be
    // available.
    commit, err := t.chain.Commit(s.height)
    if errors.Is(err, dps.ErrUnavailable) {
        log.Debug().Msg("waiting for next state commitment")
        time.Sleep(t.cfg.WaitInterval)
        return nil
    }
    if err != nil {
        return fmt.Errorf("could not get commit: %w", err)
    }
    collections, err := t.chain.Collections(s.height)
    if err != nil {
        return fmt.Errorf("could not get collections: %w", err)
    }
    transactions, err := t.chain.Transactions(s.height)
    if err != nil {
        return fmt.Errorf("could not get transactions: %w", err)
    }
    results, err := t.chain.Results(s.height)
    if err != nil {
        return fmt.Errorf("could not get transaction results: %w", err)
    }
    events, err := t.chain.Events(s.height)
    if err != nil {
        return fmt.Errorf("could not get events: %w", err)
    }

    // Next, all we need to do is index the remaining data and we have fully
    // processed indexing for this block height.
    err = t.write.Commit(s.height, commit)
    if err != nil {
        return fmt.Errorf("could not index commit: %w", err)
    }
    err = t.write.Collections(s.height, collections)
    if err != nil {
        return fmt.Errorf("could not index collections: %w", err)
    }
    err = t.write.Transactions(s.height, transactions)
    if err != nil {
        return fmt.Errorf("could not index transactions: %w", err)
    }
    err = t.write.Results(results)
    if err != nil {
        return fmt.Errorf("could not index transaction results: %w", err)
    }
    err = t.write.Events(s.height, events)
    if err != nil {
        return fmt.Errorf("could not index events: %w", err)
    }

    // At this point, we need to forward the `last` state commitment to
    // `next`, so we know what the state commitment was at the last finalized
    // block we processed. This will allow us to know when to stop when
    // walking back through the forest to collect trie updates.
    s.last = s.next

    // Last but not least, we need to update `next` to point to the commit we
    // have just retrieved for the new block height. This is the sentinel that
    // tells us when we have collected enough trie updates for the forest to
    // have reached the next finalized block.
    s.next = commit

    log.Info().Msg("indexed blockchain data for finalized block")

    // After indexing the blockchain data, we can go back to updating the state
    // tree until we find the commit of the finalized block. This will allow us
    // to index the payloads then.
    s.status = StatusUpdate
    return nil
}

// UpdateTree updates the state's tree. If the state's forest already matches with the next block's state commitment,
// it immediately returns and sets the state's status to StatusCollect.
func (t *Transitions) UpdateTree(s *State) error {
    if s.status != StatusUpdate {
        return fmt.Errorf("invalid status for updating tree (%s)", s.status)
    }

    log := t.log.With().Uint64("height", s.height).Hex("last", s.last[:]).Hex("next", s.next[:]).Logger()

    // If the forest contains a tree for the commit of the next finalized block,
    // we have reached our goal, and we can go to the next step in order to
    // collect the register payloads we want to index for that block.
    ok := s.forest.Has(s.next)
    if ok {
        log.Info().Hex("commit", s.next[:]).Msg("matched commit of finalized block")
        s.status = StatusCollect
        return nil
    }

    // First, we get the next tree update from the feeder. We can skip it if
    // it doesn't have any updated paths, or if we can't find the tree to apply
    // it to in the forest. This usually means that it was meant for a pruned
    // branch of the execution forest.
    update, err := t.feed.Update()
    if errors.Is(err, dps.ErrUnavailable) {
        log.Debug().Msg("waiting for next trie update")
        time.Sleep(t.cfg.WaitInterval)
        return nil
    }
    if err != nil {
        return fmt.Errorf("could not feed update: %w", err)
    }
    parent := flow.StateCommitment(update.RootHash)
    tree, ok := s.forest.Tree(parent)
    if !ok {
        log.Warn().Msg("state commitment mismatch, retrieving next trie update")
        return nil
    }

    // We then apply the update to the relevant tree, as retrieved from the
    // forest, and save the updated tree in the forest. If the tree is not new,
    // we should error, as that should not happen.
    paths, payloads := pathsPayloads(update)
    tree, err = trie.NewTrieWithUpdatedRegisters(tree, paths, payloads)
    if err != nil {
        return fmt.Errorf("could not update tree: %w", err)
    }
    s.forest.Save(tree, paths, parent)

    hash := tree.RootHash()
    log.Info().Hex("commit", hash[:]).Int("registers", len(paths)).Msg("updated tree with register payloads")

    return nil
}

// CollectRegisters reads the payloads for the next block to be indexed from the state's forest, unless payload
// indexing is disabled.
func (t *Transitions) CollectRegisters(s *State) error {
    log := t.log.With().Uint64("height", s.height).Hex("commit", s.next[:]).Logger()
    if s.status != StatusCollect {
        return fmt.Errorf("invalid status for collecting registers (%s)", s.status)
    }

    // If indexing payloads is disabled, we can bypass collection and indexing
    // of payloads and just go straight to forwarding the height to the next
    // finalized block.
    if t.cfg.SkipRegisters {
        s.status = StatusForward
        return nil
    }

    // If we index payloads, we are basically stepping back from (and including)
    // the tree that corresponds to the next finalized block all the way up to
    // (and excluding) the tree for the last finalized block we indexed. To do
    // so, we will use the parent state commit to retrieve the parent trees from
    // the forest, and we use the paths we recorded changes on to retrieve the
    // changed payloads at each step.
    commit := s.next
    for commit != s.last {

        // We do this check only once, so that we don't need to do it for
        // each item we retrieve. The tree should always be there, but we
        // should check just to not fail silently.
        ok := s.forest.Has(commit)
        if !ok {
            return fmt.Errorf("could not load tree (commit: %x)", commit)
        }

        // For each path, we retrieve the payload and add it to the registers we
        // will index later. If we already have a payload for the path, it is
        // more recent as we iterate backwards in time, so we can skip the
        // outdated payload.
        // NOTE: We read from the tree one by one here, as the performance
        // overhead is minimal compared to the disk i/o for badger, and it
        // allows us to ignore sorting of paths.
        tree, _ := s.forest.Tree(commit)
        paths, _ := s.forest.Paths(commit)
        for _, path := range paths {
            _, ok := s.registers[path]
            if ok {
                continue
            }
            payloads := tree.UnsafeRead([]ledger.Path{path})
            s.registers[path] = payloads[0]
        }

        log.Debug().Int("batch", len(paths)).Msg("collected register batch for finalized block")

        // We now step back to the parent of the current state trie.
        parent, _ := s.forest.Parent(commit)
        commit = parent
    }

    log.Info().Int("registers", len(s.registers)).Msg("collected all registers for finalized block")

    // At this point, we have collected all the payloads, so we go to the next
    // step, where we will index them.
    s.status = StatusMap
    return nil
}

// MapRegisters maps the collected registers to the current block.
func (t *Transitions) MapRegisters(s *State) error {
    if s.status != StatusMap {
        return fmt.Errorf("invalid status for indexing registers (%s)", s.status)
    }

    log := t.log.With().Uint64("height", s.height).Hex("commit", s.next[:]).Logger()

    // If there are no registers left to be indexed, we can go to the next step,
    // which is about forwarding the height to the next finalized block.
    if len(s.registers) == 0 {
        log.Info().Msg("indexed all registers for finalized block")
        s.status = StatusForward
        return nil
    }

    // We will now collect and index 1000 registers at a time. This gives the
    // FSM the chance to exit the loop between every 1000 payloads we index. It
    // doesn't really matter for badger if they are in random order, so this
    // way of iterating should be fine.
    n := 1000
    paths := make([]ledger.Path, 0, n)
    payloads := make([]*ledger.Payload, 0, n)
    for path, payload := range s.registers {
        paths = append(paths, path)
        payloads = append(payloads, payload)
        delete(s.registers, path)
        if len(paths) >= n {
            break
        }
    }

    // Then we store the (maximum) 1000 paths and payloads.
    err := t.write.Payloads(s.height, paths, payloads)
    if err != nil {
        return fmt.Errorf("could not index registers: %w", err)
    }

    log.Debug().Int("batch", len(paths)).Int("remaining", len(s.registers)).Msg("indexed register batch for finalized block")

    return nil
}

// ForwardHeight increments the height at which the mapping operates, and updates the last indexed height.
func (t *Transitions) ForwardHeight(s *State) error {
    if s.status != StatusForward {
        return fmt.Errorf("invalid status for forwarding height (%s)", s.status)
    }

    // After finishing the indexing of the payloads for a finalized block, or
    // skipping it, we should document the last indexed height. On the first
    // pass, we will also index the first indexed height here.
    var err error
    t.once.Do(func() { err = t.write.First(s.height) })
    if err != nil {
        return fmt.Errorf("could not index first height: %w", err)
    }
    err = t.write.Last(s.height)
    if err != nil {
        return fmt.Errorf("could not index last height: %w", err)
    }

    // Now that we have indexed the heights, we can forward to the next height,
    // and reset the forest to free up memory.
    s.height++
    s.forest.Reset(s.next)

    t.log.Info().Uint64("height", s.height).Msg("forwarded finalized block to next height")

    // Once the height is forwarded, we can set the status so that we index
    // the blockchain data next.
    s.status = StatusIndex
    return nil
}

This refactoring allowed us to write simple, concise tests: each one calls a single transition function on the state and makes assertions on the resulting state.
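The pattern can be sketched in isolation as follows. `Chain`, `ChainMock`, `BaselineChain`, `ErrDummy` and `Bootstrap` are hypothetical simplifications written for this example, not the flow-dps types. The sketch assumes a mock whose behaviour is overridden through a function field, which is the style the tests in this guide rely on (for example `chain.RootFunc`).

```go
package main

import (
    "errors"
    "fmt"
)

// ErrDummy is a hypothetical generic error used to simulate failures.
var ErrDummy = errors.New("dummy error")

// Chain is a hypothetical one-method stand-in for a chain dependency.
type Chain interface {
    Root() (uint64, error)
}

// ChainMock implements Chain by delegating to a replaceable function field.
type ChainMock struct {
    RootFunc func() (uint64, error)
}

// Root calls the function currently stored in RootFunc.
func (c *ChainMock) Root() (uint64, error) {
    return c.RootFunc()
}

// BaselineChain returns a mock that succeeds, so that each test only has to
// override the behaviour it wants to exercise.
func BaselineChain() *ChainMock {
    return &ChainMock{
        RootFunc: func() (uint64, error) { return 42, nil },
    }
}

// State is a hypothetical, reduced version of an FSM state.
type State struct {
    height uint64
    ready  bool
}

// Bootstrap is a hypothetical transition that loads the root height into
// the state and marks the state as ready.
func Bootstrap(chain Chain, s *State) error {
    height, err := chain.Root()
    if err != nil {
        return fmt.Errorf("could not get root height: %w", err)
    }
    s.height = height
    s.ready = true
    return nil
}

func main() {
    // Nominal case: the baseline mock succeeds and the state is updated.
    s := &State{}
    err := Bootstrap(BaselineChain(), s)
    fmt.Println(err == nil, s.height, s.ready)

    // Failure case: override a single function on the mock and verify that
    // the transition reports the error and leaves the state untouched.
    chain := BaselineChain()
    chain.RootFunc = func() (uint64, error) { return 0, ErrDummy }
    s = &State{}
    err = Bootstrap(chain, s)
    fmt.Println(err != nil, s.ready)
}
```

In a real test file, the two cases in `main` would each become a subtest, with the checks expressed as assertions.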

mapper_new_internal_test.go
package mapper

import (
    "sync"
    "testing"

    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"

    "github.com/onflow/flow-go/ledger"
    "github.com/onflow/flow-go/ledger/complete/mtrie/trie"
    "github.com/onflow/flow-go/model/flow"

    "github.com/optakt/flow-dps/models/dps"
    "github.com/optakt/flow-dps/testing/mocks"
)

func TestNewTransitions(t *testing.T) {
    t.Run("nominal case, without options", func(t *testing.T) {
        load := mocks.BaselineLoader(t)
        chain := mocks.BaselineChain(t)
        feed := mocks.BaselineFeeder(t)
        read := mocks.BaselineReader(t)
        write := mocks.BaselineWriter(t)

        tr := NewTransitions(mocks.NoopLogger, load, chain, feed, read, write)

        assert.NotNil(t, tr)
        assert.Equal(t, chain, tr.chain)
        assert.Equal(t, feed, tr.feed)
        assert.Equal(t, write, tr.write)
        assert.NotNil(t, tr.once)
        assert.Equal(t, DefaultConfig, tr.cfg)
    })

    t.Run("nominal case, with option", func(t *testing.T) {
        load := mocks.BaselineLoader(t)
        chain := mocks.BaselineChain(t)
        feed := mocks.BaselineFeeder(t)
        read := mocks.BaselineReader(t)
        write := mocks.BaselineWriter(t)

        skip := true
        tr := NewTransitions(mocks.NoopLogger, load, chain, feed, read, write,
            WithSkipRegisters(skip),
        )

        assert.NotNil(t, tr)
        assert.Equal(t, chain, tr.chain)
        assert.Equal(t, feed, tr.feed)
        assert.Equal(t, write, tr.write)
        assert.NotNil(t, tr.once)

        assert.NotEqual(t, DefaultConfig, tr.cfg)
        assert.Equal(t, skip, tr.cfg.SkipRegisters)
        assert.Equal(t, DefaultConfig.WaitInterval, tr.cfg.WaitInterval)
    })
}

func TestTransitions_BootstrapState(t *testing.T) {
    t.Run("nominal case", func(t *testing.T) {
        t.Parallel()

        tr, st := baselineFSM(t, StatusBootstrap)

        // Copy state in local scope so that we can override its SaveFunc without impacting other
        // tests running in parallel.
        var saveCalled bool
        forest := mocks.BaselineForest(t, true)
        forest.SaveFunc = func(tree *trie.MTrie, paths []ledger.Path, parent flow.StateCommitment) {
            if !saveCalled {
                assert.True(t, tree.IsEmpty())
                assert.Nil(t, paths)
                assert.Zero(t, parent)
                saveCalled = true
                return
            }
            assert.False(t, tree.IsEmpty())
            assert.Len(t, tree.AllPayloads(), len(paths))
            assert.Len(t, paths, 3) // Expect the three paths from leaves.
            assert.NotZero(t, parent)
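        }

        // Inject the mocked forest into the state, so that the assertions in
        // SaveFunc are actually executed by the transition.
        st.forest = forest

        err := tr.BootstrapState(st)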
        assert.NoError(t, err)
    })

    t.Run("invalid state", func(t *testing.T) {
        t.Parallel()

        tr, st := baselineFSM(t, StatusForward)

        err := tr.BootstrapState(st)
        assert.Error(t, err)
    })

    t.Run("handles failure to get root height", func(t *testing.T) {
        t.Parallel()

        tr, st := baselineFSM(t, StatusBootstrap)

        chain := mocks.BaselineChain(t)
        chain.RootFunc = func() (uint64, error) {
            return 0, mocks.GenericError
        }

        tr.chain = chain

        err := tr.BootstrapState(st)
        assert.Error(t, err)
    })
}

func TestTransitions_IndexChain(t *testing.T) {
    t.Run("nominal case", func(t *testing.T) {
        t.Parallel()

        chain := mocks.BaselineChain(t)
        chain.HeaderFunc = func(height uint64) (*flow.Header, error) {
            assert.Equal(t, mocks.GenericHeight, height)

            return mocks.GenericHeader, nil
        }
        chain.CommitFunc = func(height uint64) (flow.StateCommitment, error) {
            assert.Equal(t, mocks.GenericHeight, height)

            return mocks.GenericCommit(0), nil
        }
        chain.CollectionsFunc = func(height uint64) ([]*flow.LightCollection, error) {
            assert.Equal(t, mocks.GenericHeight, height)

            return mocks.GenericCollections(2), nil
        }
        chain.GuaranteesFunc = func(height uint64) ([]*flow.CollectionGuarantee, error) {
            assert.Equal(t, mocks.GenericHeight, height)

            return mocks.GenericGuarantees(2), nil
        }
        chain.TransactionsFunc = func(height uint64) ([]*flow.TransactionBody, error) {
            assert.Equal(t, mocks.GenericHeight, height)

            return mocks.GenericTransactions(4), nil
        }
        chain.ResultsFunc = func(height uint64) ([]*flow.TransactionResult, error) {
            assert.Equal(t, mocks.GenericHeight, height)

            return mocks.GenericResults(4), nil
        }
        chain.EventsFunc = func(height uint64) ([]flow.Event, error) {
            assert.Equal(t, mocks.GenericHeight, height)

            return mocks.GenericEvents(8), nil
        }
        chain.SealsFunc = func(height uint64) ([]*flow.Seal, error) {
            assert.Equal(t, mocks.GenericHeight, height)

            return mocks.GenericSeals(4), nil
        }

        write := mocks.BaselineWriter(t)
        write.HeaderFunc = func(height uint64, header *flow.Header) error {
            assert.Equal(t, mocks.GenericHeight, height)
            assert.Equal(t, mocks.GenericHeader, header)

            return nil
        }
        write.CommitFunc = func(height uint64, commit flow.StateCommitment) error {
            assert.Equal(t, mocks.GenericHeight, height)
            assert.Equal(t, mocks.GenericCommit(0), commit)

            return nil
        }
        write.HeightFunc = func(blockID flow.Identifier, height uint64) error {
            assert.Equal(t, mocks.GenericHeight, height)
            assert.Equal(t, mocks.GenericHeader.ID(), blockID)

            return nil
        }
        write.CollectionsFunc = func(height uint64, collections []*flow.LightCollection) error {
            assert.Equal(t, mocks.GenericHeight, height)
            assert.Equal(t, mocks.GenericCollections(2), collections)

            return nil
        }
        write.GuaranteesFunc = func(height uint64, guarantees []*flow.CollectionGuarantee) error {
            assert.Equal(t, mocks.GenericHeight, height)
            assert.Equal(t, mocks.GenericGuarantees(2), guarantees)

            return nil
        }
        write.TransactionsFunc = func(height uint64, transactions []*flow.TransactionBody) error {
            assert.Equal(t, mocks.GenericHeight, height)
            assert.Equal(t, mocks.GenericTransactions(4), transactions)

            return nil
        }
        write.ResultsFunc = func(results []*flow.TransactionResult) error {
            assert.Equal(t, mocks.GenericResults(4), results)

            return nil
        }
        write.EventsFunc = func(height uint64, events []flow.Event) error {
            assert.Equal(t, mocks.GenericHeight, height)
            assert.Equal(t, mocks.GenericEvents(8), events)

            return nil
        }
        write.SealsFunc = func(height uint64, seals []*flow.Seal) error {
            assert.Equal(t, mocks.GenericHeight, height)
            assert.Equal(t, mocks.GenericSeals(4), seals)

            return nil
        }

        tr, st := baselineFSM(t, StatusIndex)
        tr.chain = chain
        tr.write = write

        err := tr.IndexChain(st)

        require.NoError(t, err)
        assert.Equal(t, StatusUpdate, st.status)
    })

    t.Run("handles invalid status", func(t *testing.T) {
        t.Parallel()

        tr, st := baselineFSM(t, StatusBootstrap)

        err := tr.IndexChain(st)

        assert.Error(t, err)
    })

    t.Run("handles chain failure to retrieve commit", func(t *testing.T) {
        t.Parallel()

        chain := mocks.BaselineChain(t)
        chain.CommitFunc = func(uint64) (flow.StateCommitment, error) {
            return flow.DummyStateCommitment, mocks.GenericError
        }

        tr, st := baselineFSM(t, StatusIndex)
        tr.chain = chain

        err := tr.IndexChain(st)

        assert.Error(t, err)
    })

    t.Run("handles writer failure to index commit", func(t *testing.T) {
        t.Parallel()

        write := mocks.BaselineWriter(t)
        write.CommitFunc = func(uint64, flow.StateCommitment) error {
            return mocks.GenericError
        }

        tr, st := baselineFSM(t, StatusIndex)
        tr.write = write

        err := tr.IndexChain(st)

        assert.Error(t, err)
    })

    t.Run("handles chain failure to retrieve header", func(t *testing.T) {
        t.Parallel()

        chain := mocks.BaselineChain(t)
        chain.HeaderFunc = func(uint64) (*flow.Header, error) {
            return nil, mocks.GenericError
        }

        tr, st := baselineFSM(t, StatusIndex)
        tr.chain = chain

        err := tr.IndexChain(st)

        assert.Error(t, err)
    })

    t.Run("handles writer failure to index header", func(t *testing.T) {
        t.Parallel()

        write := mocks.BaselineWriter(t)
        write.HeaderFunc = func(uint64, *flow.Header) error {
            return mocks.GenericError
        }

        tr, st := baselineFSM(t, StatusIndex)
        tr.write = write

        err := tr.IndexChain(st)

        assert.Error(t, err)
    })

    t.Run("handles chain failure to retrieve transactions", func(t *testing.T) {
        t.Parallel()

        chain := mocks.BaselineChain(t)
        chain.TransactionsFunc = func(uint64) ([]*flow.TransactionBody, error) {
            return nil, mocks.GenericError
        }

        tr, st := baselineFSM(t, StatusIndex)
        tr.chain = chain

        err := tr.IndexChain(st)

        assert.Error(t, err)
    })

    t.Run("handles chain failure to retrieve transaction results", func(t *testing.T) {
        t.Parallel()

        chain := mocks.BaselineChain(t)
        chain.ResultsFunc = func(uint64) ([]*flow.TransactionResult, error) {
            return nil, mocks.GenericError
        }

        tr, st := baselineFSM(t, StatusIndex)
        tr.chain = chain

        err := tr.IndexChain(st)

        assert.Error(t, err)
    })

    t.Run("handles writer failure to index transaction results", func(t *testing.T) {
        t.Parallel()

        write := mocks.BaselineWriter(t)
        write.ResultsFunc = func([]*flow.TransactionResult) error {
            return mocks.GenericError
        }

        tr, st := baselineFSM(t, StatusIndex)
        tr.write = write

        err := tr.IndexChain(st)

        assert.Error(t, err)
    })

    t.Run("handles chain failure to retrieve collections", func(t *testing.T) {
        t.Parallel()

        chain := mocks.BaselineChain(t)
        chain.CollectionsFunc = func(uint64) ([]*flow.LightCollection, error) {
            return nil, mocks.GenericError
        }

        tr, st := baselineFSM(t, StatusIndex)
        tr.chain = chain

        err := tr.IndexChain(st)

        assert.Error(t, err)
    })

    t.Run("handles writer failure to index collections", func(t *testing.T) {
        t.Parallel()

        write := mocks.BaselineWriter(t)
        write.CollectionsFunc = func(uint64, []*flow.LightCollection) error {
            return mocks.GenericError
        }

        tr, st := baselineFSM(t, StatusIndex)
        tr.write = write

        err := tr.IndexChain(st)

        assert.Error(t, err)
    })

    t.Run("handles chain failure to retrieve guarantees", func(t *testing.T) {
        t.Parallel()

        chain := mocks.BaselineChain(t)
        chain.GuaranteesFunc = func(uint64) ([]*flow.CollectionGuarantee, error) {
            return nil, mocks.GenericError
        }

        tr, st := baselineFSM(t, StatusIndex)
        tr.chain = chain

        err := tr.IndexChain(st)

        assert.Error(t, err)
    })

    t.Run("handles writer failure to index guarantees", func(t *testing.T) {
        t.Parallel()

        write := mocks.BaselineWriter(t)
        write.GuaranteesFunc = func(uint64, []*flow.CollectionGuarantee) error {
            return mocks.GenericError
        }

        tr, st := baselineFSM(t, StatusIndex)
        tr.write = write

        err := tr.IndexChain(st)

        assert.Error(t, err)
    })

    t.Run("handles chain failure to retrieve events", func(t *testing.T) {
        t.Parallel()

        chain := mocks.BaselineChain(t)
        chain.EventsFunc = func(uint64) ([]flow.Event, error) {
            return nil, mocks.GenericError
        }

        tr, st := baselineFSM(t, StatusIndex)
        tr.chain = chain

        err := tr.IndexChain(st)

        assert.Error(t, err)
    })

    t.Run("handles writer failure to index events", func(t *testing.T) {
        t.Parallel()

        write := mocks.BaselineWriter(t)
        write.EventsFunc = func(uint64, []flow.Event) error {
            return mocks.GenericError
        }

        tr, st := baselineFSM(t, StatusIndex)
        tr.write = write

        err := tr.IndexChain(st)

        assert.Error(t, err)
    })

    t.Run("handles chain failure to retrieve seals", func(t *testing.T) {
        t.Parallel()

        chain := mocks.BaselineChain(t)
        chain.SealsFunc = func(uint64) ([]*flow.Seal, error) {
            return nil, mocks.GenericError
        }

        tr, st := baselineFSM(t, StatusIndex)
        tr.chain = chain

        err := tr.IndexChain(st)

        assert.Error(t, err)
    })

    t.Run("handles writer failure to index seals", func(t *testing.T) {
        t.Parallel()

        write := mocks.BaselineWriter(t)
        write.SealsFunc = func(uint64, []*flow.Seal) error {
            return mocks.GenericError
        }

        tr, st := baselineFSM(t, StatusIndex)
        tr.write = write

        err := tr.IndexChain(st)

        assert.Error(t, err)
    })
}

func TestTransitions_UpdateTree(t *testing.T) {
    update := mocks.GenericTrieUpdate(0)
    tree := mocks.GenericTrie

    t.Run("nominal case without match", func(t *testing.T) {
        t.Parallel()

        tr, st := baselineFSM(t, StatusUpdate)

        forest := mocks.BaselineForest(t, false)
        forest.SaveFunc = func(tree *trie.MTrie, paths []ledger.Path, parent flow.StateCommitment) {
            // Parent is RootHash of the mocks.GenericTrie.
            assert.Equal(t, update.RootHash[:], parent[:])
            assert.ElementsMatch(t, paths, update.Paths)
            assert.NotZero(t, tree)
        }
        forest.TreeFunc = func(commit flow.StateCommitment) (*trie.MTrie, bool) {
            assert.Equal(t, update.RootHash[:], commit[:])
            return tree, true
        }
        st.forest = forest

        err := tr.UpdateTree(st)

        require.NoError(t, err)
        assert.Equal(t, StatusUpdate, st.status)
    })

    t.Run("nominal case with no available update temporarily", func(t *testing.T) {
        t.Parallel()

        tr, st := baselineFSM(t, StatusUpdate)

        // Set up the mock feeder to return an unavailable error on the first call and return successfully
        // to subsequent calls.
        var updateCalled bool
        feeder := mocks.BaselineFeeder(t)
        feeder.UpdateFunc = func() (*ledger.TrieUpdate, error) {
            if !updateCalled {
                updateCalled = true
                return nil, dps.ErrUnavailable
            }
            return mocks.GenericTrieUpdate(0), nil
        }
        tr.feed = feeder

        forest := mocks.BaselineForest(t, true)
        forest.HasFunc = func(flow.StateCommitment) bool {
            return updateCalled
        }
        st.forest = forest

        // The first call should not error, but it should not move the FSM to the collecting status either.
        // The status should instead remain updating until a match is found.
        err := tr.UpdateTree(st)

        require.NoError(t, err)
        assert.Equal(t, StatusUpdate, st.status)

        // The second call is now successful and matches.
        err = tr.UpdateTree(st)

        require.NoError(t, err)
        assert.Equal(t, StatusCollect, st.status)
    })

    t.Run("nominal case with match", func(t *testing.T) {
        t.Parallel()

        tr, st := baselineFSM(t, StatusUpdate)

        err := tr.UpdateTree(st)

        require.NoError(t, err)
        assert.Equal(t, StatusCollect, st.status)
    })

    t.Run("handles invalid status", func(t *testing.T) {
        t.Parallel()

        tr, st := baselineFSM(t, StatusBootstrap)

        err := tr.UpdateTree(st)

        assert.Error(t, err)
    })

    t.Run("handles feeder update failure", func(t *testing.T) {
        t.Parallel()

        feed := mocks.BaselineFeeder(t)
        feed.UpdateFunc = func() (*ledger.TrieUpdate, error) {
            return nil, mocks.GenericError
        }

        tr, st := baselineFSM(t, StatusUpdate)
        st.forest = mocks.BaselineForest(t, false)
        tr.feed = feed

        err := tr.UpdateTree(st)

        assert.Error(t, err)
    })

    t.Run("handles forest parent tree not found", func(t *testing.T) {
        t.Parallel()

        forest := mocks.BaselineForest(t, false)
        forest.TreeFunc = func(_ flow.StateCommitment) (*trie.MTrie, bool) {
            return nil, false
        }

        tr, st := baselineFSM(t, StatusUpdate)
        st.forest = forest

        err := tr.UpdateTree(st)

        assert.NoError(t, err)
    })
}

func TestTransitions_CollectRegisters(t *testing.T) {
    t.Run("nominal case", func(t *testing.T) {
        t.Parallel()

        forest := mocks.BaselineForest(t, true)
        forest.ParentFunc = func(commit flow.StateCommitment) (flow.StateCommitment, bool) {
            assert.Equal(t, mocks.GenericCommit(0), commit)

            return mocks.GenericCommit(1), true
        }

        tr, st := baselineFSM(t, StatusCollect)
        st.forest = forest

        err := tr.CollectRegisters(st)

        require.NoError(t, err)
        assert.Equal(t, StatusMap, st.status)
        for _, wantPath := range mocks.GenericLedgerPaths(6) {
            assert.Contains(t, st.registers, wantPath)
        }
    })

    t.Run("indexing payloads disabled", func(t *testing.T) {
        t.Parallel()

        tr, st := baselineFSM(t, StatusCollect)
        tr.cfg.SkipRegisters = true

        err := tr.CollectRegisters(st)

        require.NoError(t, err)
        assert.Empty(t, st.registers)
        assert.Equal(t, StatusForward, st.status)
    })

    t.Run("handles invalid status", func(t *testing.T) {
        t.Parallel()

        tr, st := baselineFSM(t, StatusBootstrap)

        err := tr.CollectRegisters(st)

        assert.Error(t, err)
        assert.Empty(t, st.registers)
    })

    t.Run("handles missing tree for commit", func(t *testing.T) {
        t.Parallel()

        tr, st := baselineFSM(t, StatusCollect)
        st.forest = mocks.BaselineForest(t, false)

        err := tr.CollectRegisters(st)

        assert.Error(t, err)
        assert.Empty(t, st.registers)
    })
}

func TestTransitions_MapRegisters(t *testing.T) {
    t.Run("nominal case with registers to write", func(t *testing.T) {
        t.Parallel()

        // The second and fourth entries share the same path, so the map effectively contains 5 entries.
        testRegisters := map[ledger.Path]*ledger.Payload{
            mocks.GenericLedgerPath(0): mocks.GenericLedgerPayload(0),
            mocks.GenericLedgerPath(1): mocks.GenericLedgerPayload(1),
            mocks.GenericLedgerPath(2): mocks.GenericLedgerPayload(2),
            mocks.GenericLedgerPath(1): mocks.GenericLedgerPayload(3),
            mocks.GenericLedgerPath(4): mocks.GenericLedgerPayload(4),
            mocks.GenericLedgerPath(5): mocks.GenericLedgerPayload(5),
        }

        write := mocks.BaselineWriter(t)
        write.PayloadsFunc = func(height uint64, paths []ledger.Path, value []*ledger.Payload) error {
            assert.Equal(t, mocks.GenericHeight, height)

            // Expect the 5 entries from the map.
            assert.Len(t, paths, 5)
            assert.Len(t, value, 5)
            return nil
        }

        tr, st := baselineFSM(t, StatusMap)
        tr.write = write
        st.registers = testRegisters

        err := tr.MapRegisters(st)

        require.NoError(t, err)

        // Should not be StatusForward yet because the registers map was not empty.
        assert.Empty(t, st.registers)
        assert.Equal(t, StatusMap, st.status)
    })

    t.Run("nominal case no more registers left to write", func(t *testing.T) {
        t.Parallel()

        tr, st := baselineFSM(t, StatusMap)

        err := tr.MapRegisters(st)

        assert.NoError(t, err)
        assert.Equal(t, StatusForward, st.status)
    })

    t.Run("handles invalid status", func(t *testing.T) {
        t.Parallel()

        testRegisters := map[ledger.Path]*ledger.Payload{
            mocks.GenericLedgerPath(0): mocks.GenericLedgerPayload(0),
            mocks.GenericLedgerPath(1): mocks.GenericLedgerPayload(1),
            mocks.GenericLedgerPath(2): mocks.GenericLedgerPayload(2),
            mocks.GenericLedgerPath(3): mocks.GenericLedgerPayload(3),
            mocks.GenericLedgerPath(4): mocks.GenericLedgerPayload(4),
            mocks.GenericLedgerPath(5): mocks.GenericLedgerPayload(5),
        }

        tr, st := baselineFSM(t, StatusBootstrap)
        st.registers = testRegisters

        err := tr.MapRegisters(st)

        assert.Error(t, err)
    })

    t.Run("handles writer failure", func(t *testing.T) {
        t.Parallel()

        testRegisters := map[ledger.Path]*ledger.Payload{
            mocks.GenericLedgerPath(0): mocks.GenericLedgerPayload(0),
            mocks.GenericLedgerPath(1): mocks.GenericLedgerPayload(1),
            mocks.GenericLedgerPath(2): mocks.GenericLedgerPayload(2),
            mocks.GenericLedgerPath(3): mocks.GenericLedgerPayload(3),
            mocks.GenericLedgerPath(4): mocks.GenericLedgerPayload(4),
            mocks.GenericLedgerPath(5): mocks.GenericLedgerPayload(5),
        }

        write := mocks.BaselineWriter(t)
        write.PayloadsFunc = func(uint64, []ledger.Path, []*ledger.Payload) error { return mocks.GenericError }

        tr, st := baselineFSM(t, StatusMap)
        tr.write = write
        st.registers = testRegisters

        err := tr.MapRegisters(st)

        assert.Error(t, err)
    })
}

func TestTransitions_ForwardHeight(t *testing.T) {
    t.Run("nominal case", func(t *testing.T) {
        t.Parallel()

        var (
            firstCalled int
            lastCalled  int
        )
        write := mocks.BaselineWriter(t)
        write.FirstFunc = func(height uint64) error {
            assert.Equal(t, mocks.GenericHeight, height)
            firstCalled++
            return nil
        }
        write.LastFunc = func(height uint64) error {
            assert.Equal(t, mocks.GenericHeight+uint64(lastCalled), height)
            lastCalled++
            return nil
        }

        forest := mocks.BaselineForest(t, true)
        forest.ResetFunc = func(finalized flow.StateCommitment) {
            assert.Equal(t, mocks.GenericCommit(0), finalized)
        }

        tr, st := baselineFSM(t, StatusForward)
        st.forest = forest
        tr.write = write

        err := tr.ForwardHeight(st)

        assert.NoError(t, err)
        assert.Equal(t, StatusIndex, st.status)
        assert.Equal(t, mocks.GenericHeight+1, st.height)

        // Reset status to allow next call.
        st.status = StatusForward
        err = tr.ForwardHeight(st)

        require.NoError(t, err)
        assert.Equal(t, StatusIndex, st.status)
        assert.Equal(t, mocks.GenericHeight+2, st.height)

        // First should have been called only once.
        assert.Equal(t, 1, firstCalled)
    })

    t.Run("handles invalid status", func(t *testing.T) {
        t.Parallel()

        tr, st := baselineFSM(t, StatusBootstrap)

        err := tr.ForwardHeight(st)

        assert.Error(t, err)
    })

    t.Run("handles writer error on first", func(t *testing.T) {
        t.Parallel()

        write := mocks.BaselineWriter(t)
        write.FirstFunc = func(uint64) error {
            return mocks.GenericError
        }

        tr, st := baselineFSM(t, StatusForward)
        tr.write = write

        err := tr.ForwardHeight(st)

        assert.Error(t, err)
    })

    t.Run("handles writer error on last", func(t *testing.T) {
        t.Parallel()

        write := mocks.BaselineWriter(t)
        write.LastFunc = func(uint64) error {
            return mocks.GenericError
        }

        tr, st := baselineFSM(t, StatusForward)
        tr.write = write

        err := tr.ForwardHeight(st)

        assert.Error(t, err)
    })
}

func TestTransitions_InitializeMapper(t *testing.T) {
    t.Run("switches state to StatusBootstrap if configured to do so", func(t *testing.T) {
        t.Parallel()

        tr, st := baselineFSM(t, StatusInitialize)

        tr.cfg.BootstrapState = true

        err := tr.InitializeMapper(st)

        require.NoError(t, err)
        assert.Equal(t, StatusBootstrap, st.status)
    })

    t.Run("switches state to StatusResume if no bootstrapping configured", func(t *testing.T) {
        t.Parallel()

        tr, st := baselineFSM(t, StatusInitialize)

        tr.cfg.BootstrapState = false

        err := tr.InitializeMapper(st)

        require.NoError(t, err)
        assert.Equal(t, StatusResume, st.status)
    })

    t.Run("handles invalid status", func(t *testing.T) {
        t.Parallel()

        tr, st := baselineFSM(t, StatusForward)

        err := tr.InitializeMapper(st)

        require.Error(t, err)
    })
}

func TestTransitions_ResumeIndexing(t *testing.T) {
    header := mocks.GenericHeader
    tree := mocks.GenericTrie
    commit := flow.StateCommitment(tree.RootHash())
    differentCommit := mocks.GenericCommit(0)

    t.Run("nominal case", func(t *testing.T) {
        t.Parallel()

        chain := mocks.BaselineChain(t)
        chain.RootFunc = func() (uint64, error) {
            return header.Height, nil
        }

        writer := mocks.BaselineWriter(t)
        writer.FirstFunc = func(height uint64) error {
            assert.Equal(t, header.Height, height)

            return nil
        }

        loader := mocks.BaselineLoader(t)
        loader.TrieFunc = func() (*trie.MTrie, error) {
            return tree, nil
        }

        reader := mocks.BaselineReader(t)
        reader.LastFunc = func() (uint64, error) {
            return header.Height, nil
        }
        reader.CommitFunc = func(height uint64) (flow.StateCommitment, error) {
            assert.Equal(t, header.Height, height)

            return commit, nil
        }

        tr, st := baselineFSM(
            t,
            StatusResume,
            withReader(reader),
            withWriter(writer),
            withLoader(loader),
            withChain(chain),
        )

        err := tr.ResumeIndexing(st)

        require.NoError(t, err)
        assert.Equal(t, StatusIndex, st.status)
        assert.Equal(t, header.Height+1, st.height)
        assert.Equal(t, flow.DummyStateCommitment, st.last)
        assert.Equal(t, commit, st.next)
    })

    t.Run("handles chain failure on Root", func(t *testing.T) {
        t.Parallel()

        chain := mocks.BaselineChain(t)
        chain.RootFunc = func() (uint64, error) {
            return 0, mocks.GenericError
        }

        loader := mocks.BaselineLoader(t)
        loader.TrieFunc = func() (*trie.MTrie, error) {
            return tree, nil
        }

        reader := mocks.BaselineReader(t)
        reader.LastFunc = func() (uint64, error) {
            return header.Height, nil
        }
        reader.CommitFunc = func(uint64) (flow.StateCommitment, error) {
            return commit, nil
        }

        tr, st := baselineFSM(
            t,
            StatusResume,
            withReader(reader),
            withLoader(loader),
            withChain(chain),
        )

        err := tr.ResumeIndexing(st)

        assert.Error(t, err)
    })

    t.Run("handles writer failure on First", func(t *testing.T) {
        t.Parallel()

        chain := mocks.BaselineChain(t)
        chain.RootFunc = func() (uint64, error) {
            return header.Height, nil
        }

        writer := mocks.BaselineWriter(t)
        writer.FirstFunc = func(uint64) error {
            return mocks.GenericError
        }

        loader := mocks.BaselineLoader(t)
        loader.TrieFunc = func() (*trie.MTrie, error) {
            return tree, nil
        }

        reader := mocks.BaselineReader(t)
        reader.LastFunc = func() (uint64, error) {
            return header.Height, nil
        }
        reader.CommitFunc = func(uint64) (flow.StateCommitment, error) {
            return commit, nil
        }

        tr, st := baselineFSM(
            t,
            StatusResume,
            withWriter(writer),
            withReader(reader),
            withLoader(loader),
            withChain(chain),
        )

        err := tr.ResumeIndexing(st)

        assert.Error(t, err)
    })

    t.Run("handles reader failure on Last", func(t *testing.T) {
        t.Parallel()

        chain := mocks.BaselineChain(t)
        chain.RootFunc = func() (uint64, error) {
            return header.Height, nil
        }

        loader := mocks.BaselineLoader(t)
        loader.TrieFunc = func() (*trie.MTrie, error) {
            return tree, nil
        }

        reader := mocks.BaselineReader(t)
        reader.LastFunc = func() (uint64, error) {
            return 0, mocks.GenericError
        }
        reader.CommitFunc = func(uint64) (flow.StateCommitment, error) {
            return commit, nil
        }

        tr, st := baselineFSM(
            t,
            StatusResume,
            withReader(reader),
            withLoader(loader),
            withChain(chain),
        )

        err := tr.ResumeIndexing(st)

        assert.Error(t, err)
    })

    t.Run("handles reader failure on Commit", func(t *testing.T) {
        t.Parallel()

        chain := mocks.BaselineChain(t)
        chain.RootFunc = func() (uint64, error) {
            return header.Height, nil
        }

        loader := mocks.BaselineLoader(t)
        loader.TrieFunc = func() (*trie.MTrie, error) {
            return tree, nil
        }

        reader := mocks.BaselineReader(t)
        reader.LastFunc = func() (uint64, error) {
            return header.Height, nil
        }
        reader.CommitFunc = func(uint64) (flow.StateCommitment, error) {
            return flow.DummyStateCommitment, mocks.GenericError
        }

        tr, st := baselineFSM(
            t,
            StatusResume,
            withReader(reader),
            withLoader(loader),
            withChain(chain),
        )

        err := tr.ResumeIndexing(st)

        assert.Error(t, err)
    })

    t.Run("handles loader failure on Trie", func(t *testing.T) {
        t.Parallel()

        chain := mocks.BaselineChain(t)
        chain.RootFunc = func() (uint64, error) {
            return header.Height, nil
        }

        loader := mocks.BaselineLoader(t)
        loader.TrieFunc = func() (*trie.MTrie, error) {
            return nil, mocks.GenericError
        }

        reader := mocks.BaselineReader(t)
        reader.LastFunc = func() (uint64, error) {
            return header.Height, nil
        }
        reader.CommitFunc = func(uint64) (flow.StateCommitment, error) {
            return commit, nil
        }

        tr, st := baselineFSM(
            t,
            StatusResume,
            withReader(reader),
            withLoader(loader),
            withChain(chain),
        )

        err := tr.ResumeIndexing(st)

        assert.Error(t, err)
    })

    t.Run("handles mismatch between tree root hash and indexed commit", func(t *testing.T) {
        t.Parallel()

        chain := mocks.BaselineChain(t)
        chain.RootFunc = func() (uint64, error) {
            return header.Height, nil
        }

        loader := mocks.BaselineLoader(t)
        loader.TrieFunc = func() (*trie.MTrie, error) {
            return tree, nil
        }

        reader := mocks.BaselineReader(t)
        reader.LastFunc = func() (uint64, error) {
            return header.Height, nil
        }
        reader.CommitFunc = func(uint64) (flow.StateCommitment, error) {
            return differentCommit, nil
        }

        tr, st := baselineFSM(
            t,
            StatusResume,
            withReader(reader),
            withLoader(loader),
            withChain(chain),
        )

        err := tr.ResumeIndexing(st)

        assert.Error(t, err)
    })

    t.Run("handles invalid status", func(t *testing.T) {
        t.Parallel()

        chain := mocks.BaselineChain(t)
        chain.RootFunc = func() (uint64, error) {
            return header.Height, nil
        }

        loader := mocks.BaselineLoader(t)
        loader.TrieFunc = func() (*trie.MTrie, error) {
            return tree, nil
        }

        reader := mocks.BaselineReader(t)
        reader.LastFunc = func() (uint64, error) {
            return header.Height, nil
        }
        reader.CommitFunc = func(uint64) (flow.StateCommitment, error) {
            return commit, nil
        }

        tr, st := baselineFSM(
            t,
            StatusForward,
            withReader(reader),
            withLoader(loader),
            withChain(chain),
        )

        err := tr.ResumeIndexing(st)

        assert.Error(t, err)
    })
}

func baselineFSM(t *testing.T, status Status, opts ...func(tr *Transitions)) (*Transitions, *State) {
    t.Helper()

    load := mocks.BaselineLoader(t)
    chain := mocks.BaselineChain(t)
    feeder := mocks.BaselineFeeder(t)
    read := mocks.BaselineReader(t)
    write := mocks.BaselineWriter(t)
    forest := mocks.BaselineForest(t, true)

    once := &sync.Once{}
    doneCh := make(chan struct{})

    tr := Transitions{
        cfg: Config{
            BootstrapState: false,
            SkipRegisters:  false,
            WaitInterval:   0,
        },
        log:   mocks.NoopLogger,
        load:  load,
        chain: chain,
        feed:  feeder,
        read:  read,
        write: write,
        once:  once,
    }

    for _, opt := range opts {
        opt(&tr)
    }

    st := State{
        forest:    forest,
        status:    status,
        height:    mocks.GenericHeight,
        last:      mocks.GenericCommit(1),
        next:      mocks.GenericCommit(0),
        registers: make(map[ledger.Path]*ledger.Payload),
        done:      doneCh,
    }

    return &tr, &st
}

func withLoader(load Loader) func(*Transitions) {
    return func(tr *Transitions) {
        tr.load = load
    }
}

func withChain(chain dps.Chain) func(*Transitions) {
    return func(tr *Transitions) {
        tr.chain = chain
    }
}

func withFeeder(feed Feeder) func(*Transitions) {
    return func(tr *Transitions) {
        tr.feed = feed
    }
}

func withReader(read dps.Reader) func(*Transitions) {
    return func(tr *Transitions) {
        tr.read = read
    }
}

func withWriter(write dps.Writer) func(*Transitions) {
    return func(tr *Transitions) {
        tr.write = write
    }
}
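
The `baselineFSM` helper and its `withLoader`, `withChain`, `withFeeder`, `withReader` and `withWriter` companions follow the functional options pattern: the helper builds a fully working baseline, and each subtest overrides only the dependency it cares about. The sketch below reduces the idea to its core. The `service` type, its string fields and the `baseline` function are hypothetical stand-ins, not part of the code above.

```go
package main

import "fmt"

// service is a hypothetical component with two swappable dependencies.
type service struct {
	reader string
	writer string
}

// option mutates a service while it is being assembled.
type option func(*service)

// withReader overrides the default reader.
func withReader(r string) option {
	return func(s *service) { s.reader = r }
}

// withWriter overrides the default writer.
func withWriter(w string) option {
	return func(s *service) { s.writer = w }
}

// baseline builds a service with defaults, then applies the overrides in order.
func baseline(opts ...option) *service {
	s := &service{reader: "default reader", writer: "default writer"}
	for _, opt := range opts {
		opt(s)
	}
	return s
}

func main() {
	// Only the writer is replaced; the reader keeps its baseline value.
	s := baseline(withWriter("failing writer"))
	fmt.Println(s.reader, "/", s.writer)
}
```

Keeping the defaults in one place means a failure-case subtest only has to state what differs from the nominal case.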

Integration Tests

Integration tests are essential to ensure that components work together as expected. They are usually much heavier and slower than unit tests, since they use real components instead of simple mocks and often perform filesystem or network operations, wait for events to happen, or run heavy computational tasks.

Integration tests should always be specified in a separate test package and never run internally within the tested package.

Build Tag

Because integration tests are inherently slower than unit tests, they are placed in dedicated files suffixed with _integration_test.go. Each of those files starts with a build tag directive, which excludes it from the build unless the go test command is called with the integration tag, for example go test -tags integration ./...

Both syntaxes should be specified: the //go:build <tag> form used by Go 1.17 and later, and the // +build <tag> form required by Go 1.16 and earlier. The //go:build line comes first, and a blank line must separate the directives from the package clause. The legacy form will be dropped once supporting Go 1.16 and earlier is no longer relevant.

//go:build integration
// +build integration

package dps_test

Examples

In Go, good package documentation includes not only comments for each public type and method, but also runnable examples and benchmarks in some cases. Godoc allows defining examples which are verified by running them as tests and can be manually launched by readers of the documentation on the package's Godoc webpage.

As with typical tests, examples are functions that reside in a package's _test.go files. Unlike normal test functions, though, example functions take no arguments and begin with the word Example instead of Test.

To specify the expected output of an example, end the Example function with a comment of the form // Output: <expected output>. The go test command then runs the example and fails it if the printed output differs. Examples without such a comment are compiled but not executed, so nothing verifies that they still behave as documented.
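
As a minimal sketch of the convention described above, the block below pairs a function with its example. `Shout` is a hypothetical function invented for illustration. In a real package, `ExampleShout` would sit in a _test.go file and there would be no `main` function; both are kept in one file here only so the snippet can be run on its own.

```go
package main

import (
	"fmt"
	"strings"
)

// Shout is a hypothetical function used only for illustration.
func Shout(s string) string {
	return strings.ToUpper(s) + "!"
}

// ExampleShout would normally live in a _test.go file next to Shout.
// When run by go test, its standard output is compared with the Output comment.
func ExampleShout() {
	fmt.Println(Shout("hello"))
	// Output: HELLO!
}

func main() {
	// Calling the example directly only prints; the comparison is done by go test.
	ExampleShout()
}
```

The suffix after Example ties the example to the documented identifier, following the same naming scheme as tests: `ExampleShout` documents the function `Shout`.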

Benchmarks

When a package exposes a performance-critical piece of code, that code should be benchmarked, and the benchmarks must be available so that anyone can reproduce the results on their own hardware. Benchmark results written in a markdown file are of little value if there is no way to reproduce them.

    // trie_test.go
    func BenchmarkTrie_InsertMany(b *testing.B) {
        paths, payloads := helpers.SampleRandomRegisterWrites(helpers.NewGenerator(), 1000)

        b.Run("insert 1000 elements (reference)", func(b *testing.B) {
            for i := 0; i < b.N; i++ {
                ref := reference.NewEmptyMTrie()
                ref, _ = reference.NewTrieWithUpdatedRegisters(ref, paths, payloads)
                _ = ref.RootHash()
            }
        })

        b.Run("insert 1000 elements (new)", func(b *testing.B) {
            for i := 0; i < b.N; i++ {
                tr := trie.NewEmptyTrie()
                tr, _ = tr.Mutate(paths, payloads)
                _ = tr.RootHash()
            }
        })
    }
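
For readers who want a benchmark they can run without the trie packages, here is a self-contained sketch. The `sortedCopy` function is a hypothetical workload chosen for illustration. The block shows two habits that help reproducibility: building inputs before `b.ResetTimer` so that setup is not measured, and calling `b.ReportAllocs` so that allocation figures are part of the result. It uses `testing.Benchmark` from the standard library to run the benchmark from `main`; in a real package the benchmark function would live in a _test.go file instead.

```go
package main

import (
	"fmt"
	"sort"
	"testing"
)

// sortedCopy is a hypothetical function under benchmark.
func sortedCopy(in []int) []int {
	out := make([]int, len(in))
	copy(out, in)
	sort.Ints(out)
	return out
}

// BenchmarkSortedCopy would normally live in a _test.go file.
func BenchmarkSortedCopy(b *testing.B) {
	// Build the input once, outside of the timed region.
	input := make([]int, 1000)
	for i := range input {
		input[i] = len(input) - i
	}

	b.ReportAllocs() // Include allocation statistics in the result.
	b.ResetTimer()   // Exclude the setup above from the measurement.
	for i := 0; i < b.N; i++ {
		_ = sortedCopy(input)
	}
}

func main() {
	// testing.Benchmark runs a benchmark function outside of go test.
	res := testing.Benchmark(BenchmarkSortedCopy)
	fmt.Println(res.String(), res.MemString())
}
```

Once the benchmark is placed in a _test.go file, `go test -bench=. -benchmem` runs it and prints timings and allocation figures, which lets anyone reproduce the numbers on their own machine.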