Restrict Direct Call to a Method Inside Test in Go

Dec 6, 2024  │  m. Dec 6, 2024 by عين  │  #go   #golang   #testing   #runtime  
Disclaimer: Views expressed in this software engineering blog are personal and do not represent my employer. Readers are encouraged to verify information independently.

Introduction

Hello, in this post I am going to show a cool trick for restricting direct calls to a method from inside a test in Go. This is useful when you want to make sure that a method is only called by the production code under test and never directly by the test itself.

Problem

Consider the following code. Our objective is simple: we need to make sure that the Doer inside MyStruct has done its job and confirmed the end of it.

package demo

import "fmt"

type Doer interface {
	Do()
	EndDo()
}

type MyStruct struct {
	D Doer
}

type DoPrint struct{}

func (d DoPrint) Do() {
	fmt.Println("Doing")
}

func (d DoPrint) EndDo() {
	fmt.Println("End Doing")
}

func (s *MyStruct) Run() {
	s.D.Do()
	defer s.D.EndDo()
}

In other words, we need to make sure that whenever Do is called, EndDo is also called.
To verify this, let’s write a unit test.

package demo_test

import (
	"demo"
	"testing"
)

type MockDoPrint struct {
	t               *testing.T
	doCallsCount    int
	endDoCallsCount int
}

func (d *MockDoPrint) Do() {
	d.doCallsCount++
}

func (d *MockDoPrint) EndDo() {
	d.endDoCallsCount++
}

func TestRun(t *testing.T) {
	mockDoPrint := &MockDoPrint{
		t: t,
	}
	demo := demo.MyStruct{
		D: mockDoPrint,
	}
	demo.Run()

	t.Run("Do() call count", func(t *testing.T) {
		want := 1
		got := mockDoPrint.doCallsCount
		if got != want {
			t.Fatalf("Do() call count = %d; want %d", got, want)
		}
	})

	t.Run("EndDo() call count", func(t *testing.T) {
		want := 1
		got := mockDoPrint.endDoCallsCount
		if got != want {
			t.Fatalf("EndDo() call count = %d; want %d", got, want)
		}
	})
}

The test is also straightforward: we mock the Doer interface, call the Run method, and then verify that Do and EndDo are each called exactly once inside Run.

All good, let’s run our tests:

go test -v ./...

We get the following output:

=== RUN   TestRun
=== RUN   TestRun/Do()_call_count
=== RUN   TestRun/EndDo()_call_count
--- PASS: TestRun (0.00s)
    --- PASS: TestRun/Do()_call_count (0.00s)
    --- PASS: TestRun/EndDo()_call_count (0.00s)
PASS
ok      demo    0.003s

Now let’s say that we forgot to defer the EndDo call inside our Run method.

func (s *MyStruct) Run() {
	s.D.Do()
	// defer s.D.EndDo()
}

Running our test again with the same command, go test -v ./..., gives us the following output:

=== RUN   TestRun
=== RUN   TestRun/Do()_call_count
=== RUN   TestRun/EndDo()_call_count
    demo_test.go:40: EndDo() call count = 0; want 1
--- FAIL: TestRun (0.00s)
    --- PASS: TestRun/Do()_call_count (0.00s)
    --- FAIL: TestRun/EndDo()_call_count (0.00s)
FAIL
FAIL    demo    0.003s
FAIL

The EndDo subtest fails because the developer forgot to call the EndDo method inside Run.
All good so far. Now let’s fix it: inside the test itself, let’s add demo.D.EndDo() after calling demo.Run():

func TestRun(t *testing.T) {
	// ... mock and MyStruct setup as before
	demo.Run()

	demo.D.EndDo()

	// ... subtests as before
}

Running the test again gives us the following output:

=== RUN   TestRun
=== RUN   TestRun/Do()_call_count
=== RUN   TestRun/EndDo()_call_count
--- PASS: TestRun (0.00s)
    --- PASS: TestRun/Do()_call_count (0.00s)
    --- PASS: TestRun/EndDo()_call_count (0.00s)
PASS

If you are thinking "HEY! This is not the correct fix!", you are right.
The test passes, but the production code still has the bug.
Imagine :) that the test code is big and a developer calls the EndDo method inside the test by mistake; that is exactly what I am trying to prevent.

Solution

Let’s prevent this from happening by making the mock fail the test whenever one of its methods is called directly from a test file.

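import (
	"demo"
	"runtime"
	"strings"
	"testing"
)

func (d *MockDoPrint) Do() {
	// runtime.Caller(1) describes the function that called Do.
	_, file, _, ok := runtime.Caller(1)
	if ok && strings.HasSuffix(file, "_test.go") {
		d.t.Fatal("Do() should not be called in tests")
	}
	d.doCallsCount++
}

func (d *MockDoPrint) EndDo() {
	// Same check: fail if EndDo was called from a _test.go file.
	_, file, _, ok := runtime.Caller(1)
	if ok && strings.HasSuffix(file, "_test.go") {
		d.t.Fatal("EndDo() should not be called in tests")
	}
	d.endDoCallsCount++
}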

runtime.Caller(1) returns, among other things, the file name of the function that called the current method (runtime.Caller(0) would describe the mock method itself). We can use this to check whether the caller lives in a _test.go file.

Let’s rerun the test:

=== RUN   TestRun
    demo_test.go:27: EndDo() should not be called in tests
--- FAIL: TestRun (0.00s)
FAIL
FAIL    demo    0.003s
FAIL

This time the mistake is caught: the test fails because the EndDo method is called directly from the test.
This way, we can make sure that EndDo is only called from the production code and never from the test code.


I know there are many ways to reach the same result, such as structuring the code or the test differently, but this is a cool trick that I wanted to share with you.

You can download the source code from here.