diff --git a/go.mod b/go.mod index 777a3edb1..e43ea7b95 100644 --- a/go.mod +++ b/go.mod @@ -120,6 +120,7 @@ require ( github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/shopspring/decimal v1.4.0 // indirect github.com/spf13/cast v1.7.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/x448/float16 v0.8.4 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect diff --git a/internal/gitprovider/github/github.go b/internal/gitprovider/github/github.go index d97a05bae..99ac7ebdb 100644 --- a/internal/gitprovider/github/github.go +++ b/internal/gitprovider/github/github.go @@ -41,11 +41,42 @@ func init() { gitprovider.Register(ProviderName, registration) } +type githubClient interface { + CreatePullRequest( + ctx context.Context, + owner string, + repo string, + pull *github.NewPullRequest, + ) (*github.PullRequest, *github.Response, error) + + ListPullRequests( + ctx context.Context, + owner string, + repo string, + opts *github.PullRequestListOptions, + ) ([]*github.PullRequest, *github.Response, error) + + GetPullRequests( + ctx context.Context, + owner string, + repo string, + number int, + ) (*github.PullRequest, *github.Response, error) + + AddLabelsToIssue( + ctx context.Context, + owner string, + repo string, + number int, + labels []string, + ) ([]*github.Label, *github.Response, error) +} + // provider is a GitHub implementation of gitprovider.Interface. type provider struct { // nolint: revive owner string repo string - client *github.Client + client githubClient } // NewProvider returns a GitHub-based implementation of gitprovider.Interface. @@ -81,10 +112,30 @@ func NewProvider( return &provider{ owner: owner, repo: repo, - client: client, + client: &githubClientWrapper{client}, }, nil } +type githubClientWrapper struct { + client *github.Client +} + +func (g githubClientWrapper) CreatePullRequest(ctx context.Context, owner string, repo string, pull *github.NewPullRequest) (*github.PullRequest, *github.Response, error) { + return g.client.PullRequests.Create(ctx, owner, repo, pull) +} + +func (g githubClientWrapper) ListPullRequests(ctx context.Context, owner string, repo string, opts *github.PullRequestListOptions) ([]*github.PullRequest, *github.Response, error) { + return g.client.PullRequests.List(ctx, owner, repo, opts) +} + +func (g githubClientWrapper) GetPullRequests(ctx context.Context, owner string, repo string, number int) (*github.PullRequest, *github.Response, error) { + return g.client.PullRequests.Get(ctx, owner, repo, number) +} + +func (g githubClientWrapper) AddLabelsToIssue(ctx context.Context, owner string, repo string, number int, labels []string) ([]*github.Label, *github.Response, error) { + return g.client.Issues.AddLabelsToIssue(ctx, owner, repo, number, labels) +} + // CreatePullRequest implements gitprovider.Interface. func (p *provider) CreatePullRequest( ctx context.Context, @@ -93,7 +144,7 @@ func (p *provider) CreatePullRequest( if opts == nil { opts = &gitprovider.CreatePullRequestOpts{} } - ghPR, _, err := p.client.PullRequests.Create(ctx, + ghPR, _, err := p.client.CreatePullRequest(ctx, p.owner, p.repo, &github.NewPullRequest{ @@ -112,7 +163,7 @@ func (p *provider) CreatePullRequest( } pr := convertGithubPR(*ghPR) if len(opts.Labels) > 0 { - _, _, err = p.client.Issues.AddLabelsToIssue(ctx, + _, _, err = p.client.AddLabelsToIssue(ctx, p.owner, p.repo, int(pr.Number), @@ -130,7 +181,7 @@ func (p *provider) GetPullRequest( ctx context.Context, id int64, ) (*gitprovider.PullRequest, error) { - ghPR, _, err := p.client.PullRequests.Get(ctx, p.owner, p.repo, int(id)) + ghPR, _, err := p.client.GetPullRequests(ctx, p.owner, p.repo, int(id)) if err != nil { return nil, err } @@ -171,7 +222,7 @@ func (p *provider) ListPullRequests( } prs := []gitprovider.PullRequest{} for { - ghPRs, res, err := p.client.PullRequests.List(ctx, p.owner, p.repo, &listOpts) + ghPRs, res, err := p.client.ListPullRequests(ctx, p.owner, p.repo, &listOpts) if err != nil { return nil, err } diff --git a/internal/gitprovider/github/github_test.go b/internal/gitprovider/github/github_test.go index f35753b90..0f40b01db 100644 --- a/internal/gitprovider/github/github_test.go +++ b/internal/gitprovider/github/github_test.go @@ -1,11 +1,18 @@ package github import ( + "context" + "github.com/akuity/kargo/internal/gitprovider" + "github.com/google/go-github/v56/github" "testing" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) +const testRepoOwner = "akuity" +const testRepoName = "kargo" + func TestParseGitHubURL(t *testing.T) { testCases := []struct { url string @@ -57,3 +64,240 @@ func TestParseGitHubURL(t *testing.T) { }) } } + +type mockGithubClient struct { + mock.Mock + pr *github.PullRequest + owner string + repo string + newPr *github.NewPullRequest + labels []string + listOpts *github.PullRequestListOptions +} + +func (m *mockGithubClient) ListPullRequests(ctx context.Context, owner string, repo string, opts *github.PullRequestListOptions) ([]*github.PullRequest, *github.Response, error) { + args := m.Called(ctx, owner, repo, opts) + m.owner = owner + m.repo = repo + m.listOpts = opts + return args.Get(0).([]*github.PullRequest), args.Get(1).(*github.Response), args.Error(2) +} + +func (m *mockGithubClient) GetPullRequests( + ctx context.Context, + owner string, + repo string, + number int, +) (*github.PullRequest, *github.Response, error) { + args := m.Called(ctx, owner, repo, number) + m.owner = owner + m.repo = repo + return args.Get(0).(*github.PullRequest), args.Get(1).(*github.Response), args.Error(2) +} + +func (m *mockGithubClient) AddLabelsToIssue( + ctx context.Context, + owner string, + repo string, + number int, + labels []string, +) ([]*github.Label, *github.Response, error) { + args := m.Called(ctx, owner, repo, number, labels) + m.labels = labels + return args.Get(0).([]*github.Label), args.Get(1).(*github.Response), args.Error(2) +} + +func (m *mockGithubClient) CreatePullRequest( + ctx context.Context, + owner string, + repo string, + pull *github.NewPullRequest) (*github.PullRequest, *github.Response, error) { + args := m.Called(ctx, owner, repo, pull) + m.owner = owner + m.repo = repo + m.newPr = pull + return args.Get(0).(*github.PullRequest), args.Get(1).(*github.Response), args.Error(2) +} + +func TestCreatePullRequestWithLabels(t *testing.T) { + opts := gitprovider.CreatePullRequestOpts{ + Head: "feature-branch", + Base: "main", + Title: "title", + Description: "desc", + Labels: []string{"label1", "label2"}, + } + + // set up mock + mockClient := &mockGithubClient{ + pr: &github.PullRequest{ + Number: github.Int(42), + MergeCommitSHA: github.String("sha"), + State: github.String("open"), + URL: github.String("url"), + }, + } + mockClient. + On("CreatePullRequest", context.Background(), testRepoOwner, testRepoName, mock.Anything). + Return( + &github.PullRequest{ + Number: mockClient.pr.Number, + Head: &github.PullRequestBranch{ + Ref: github.String(opts.Head), + }, + Base: &github.PullRequestBranch{ + Ref: github.String(opts.Base), + }, + Title: github.String(opts.Title), + Body: github.String(opts.Description), + MergeCommitSHA: mockClient.pr.MergeCommitSHA, + State: mockClient.pr.State, + HTMLURL: mockClient.pr.URL, + }, + &github.Response{}, + nil, + ) + mockClient. + On("AddLabelsToIssue", context.Background(), testRepoOwner, testRepoName, *mockClient.pr.Number, mock.Anything). + Return( + []*github.Label{}, + &github.Response{}, + nil, + ) + + // call the code we are testing + g := provider{ + owner: testRepoOwner, + repo: testRepoName, + client: mockClient, + } + pr, err := g.CreatePullRequest(context.Background(), &opts) + + // assert that the expectations were met + mockClient.AssertExpectations(t) + + // other assertions + require.NoError(t, err) + require.Equal(t, testRepoOwner, mockClient.owner) + require.Equal(t, testRepoName, mockClient.repo) + require.Equal(t, opts.Head, *mockClient.newPr.Head) + require.Equal(t, opts.Base, *mockClient.newPr.Base) + require.Equal(t, opts.Title, *mockClient.newPr.Title, "Expected title in new PR request to match title from options") + require.Equal(t, opts.Description, *mockClient.newPr.Body, "Expected body in new PR request to match description from options") + require.ElementsMatch(t, opts.Labels, mockClient.labels, "Expected labels passed to GitHub client to match labels from options") + + require.Equal(t, int64(*mockClient.pr.Number), pr.Number, "Expected PR number in returned object to match what was returned by GitHub") + require.Equal(t, *mockClient.pr.MergeCommitSHA, pr.MergeCommitSHA) + require.Equal(t, *mockClient.pr.URL, pr.URL) + require.True(t, pr.Open) +} + +func TestGetPullRequest(t *testing.T) { + // set up mock + mockClient := &mockGithubClient{ + pr: &github.PullRequest{ + Number: github.Int(42), + MergeCommitSHA: github.String("sha"), + State: github.String("open"), + URL: github.String("url"), + }, + } + mockClient. + On("GetPullRequests", context.Background(), testRepoOwner, testRepoName, *mockClient.pr.Number). + Return( + &github.PullRequest{ + Number: mockClient.pr.Number, + Head: &github.PullRequestBranch{ + Ref: github.String("head"), + }, + MergeCommitSHA: mockClient.pr.MergeCommitSHA, + State: mockClient.pr.State, + HTMLURL: mockClient.pr.URL, + }, + &github.Response{}, + nil, + ) + + // call the code we are testing + g := provider{ + owner: testRepoOwner, + repo: testRepoName, + client: mockClient, + } + pr, err := g.GetPullRequest(context.Background(), 42) + + // assert that the expectations were met + mockClient.AssertExpectations(t) + + // other assertions + require.NoError(t, err) + require.Equal(t, testRepoOwner, mockClient.owner) + require.Equal(t, testRepoName, mockClient.repo) + require.Equal(t, int64(*mockClient.pr.Number), pr.Number, "Expected PR number in returned object to match what was returned by GitHub") + require.Equal(t, *mockClient.pr.MergeCommitSHA, pr.MergeCommitSHA) + require.Equal(t, *mockClient.pr.URL, pr.URL) + require.True(t, pr.Open) +} + +func TestListPullRequests(t *testing.T) { + opts := gitprovider.ListPullRequestOptions{ + State: gitprovider.PullRequestStateAny, + HeadBranch: "head", + BaseBranch: "base", + } + + // set up mock + mockClient := &mockGithubClient{ + pr: &github.PullRequest{ + Number: github.Int(42), + MergeCommitSHA: github.String("sha"), + State: github.String("open"), + URL: github.String("url"), + }, + } + mockClient. + On("ListPullRequests", context.Background(), testRepoOwner, testRepoName, &github.PullRequestListOptions{ + State: "all", + Head: opts.HeadBranch, + Base: opts.BaseBranch, + Sort: "", + Direction: "", + ListOptions: github.ListOptions{ + Page: 0, + PerPage: 100, + }, + }). + Return( + []*github.PullRequest{{ + Number: mockClient.pr.Number, + Head: &github.PullRequestBranch{ + Ref: github.String("head"), + }, + MergeCommitSHA: mockClient.pr.MergeCommitSHA, + State: mockClient.pr.State, + HTMLURL: mockClient.pr.URL, + }}, + &github.Response{}, + nil, + ) + + // call the code we are testing + g := provider{ + owner: testRepoOwner, + repo: testRepoName, + client: mockClient, + } + + prs, err := g.ListPullRequests(context.Background(), &opts) + require.NoError(t, err) + + require.Equal(t, testRepoOwner, mockClient.owner) + require.Equal(t, testRepoName, mockClient.repo) + require.Equal(t, opts.HeadBranch, mockClient.listOpts.Head) + require.Equal(t, opts.BaseBranch, mockClient.listOpts.Base) + + require.Equal(t, int64(*mockClient.pr.Number), prs[0].Number) + require.Equal(t, *mockClient.pr.MergeCommitSHA, prs[0].MergeCommitSHA) + require.Equal(t, *mockClient.pr.URL, prs[0].URL) + require.True(t, prs[0].Open) +}