package client

import (
	"context"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"runtime"
	"strings"
	"testing"
	"time"

	"github.com/containerd/platforms"
	"github.com/moby/buildkit/client/llb"
	"github.com/moby/buildkit/client/llb/sourceresolver"
	gateway "github.com/moby/buildkit/frontend/gateway/client"
	pb "github.com/moby/buildkit/frontend/gateway/pb"
	opspb "github.com/moby/buildkit/solver/pb"
	sourcepolicypb "github.com/moby/buildkit/sourcepolicy/pb"
	"github.com/moby/buildkit/sourcepolicy/policysession"
	"github.com/moby/buildkit/util/testutil/integration"
	digest "github.com/opencontainers/go-digest"
	ocispecs "github.com/opencontainers/image-spec/specs-go/v1"
	"github.com/pkg/errors"
	"github.com/stretchr/testify/require"
)

// testSourcePolicySession solves single-image LLB definitions with a session
// policy provider attached and verifies the requests that provider receives.
//
// Every scenario lists the callbacks expected during one solve, in call
// order: the provider forwards the n-th request to the n-th callback and
// returns an error once the list is exhausted. A scenario with expectedError
// set must fail the solve with an error containing that text; every other
// scenario must succeed and consume all of its callbacks.
func testSourcePolicySession(t *testing.T, sb integration.Sandbox) {
	requiresLinux(t)

	ctx := sb.Context()

	c, err := New(ctx, sb.Address())
	require.NoError(t, err)
	defer c.Close()

	// alpineID is the source identifier the callbacks expect to see for
	// llb.Image("alpine").
	const alpineID = "docker-image://docker.io/library/alpine:latest"

	// allow produces the return values of a callback that accepts the source.
	allow := func() (*policysession.DecisionResponse, *pb.ResolveSourceMetaRequest, error) {
		return &policysession.DecisionResponse{
			Action: sourcepolicypb.PolicyAction_ALLOW,
		}, nil, nil
	}

	type scenario struct {
		name          string
		state         func() llb.State
		callbacks     []policysession.PolicyCallback
		expectedError string
	}

	scenarios := []scenario{
		{
			name:  "basic alpine",
			state: func() llb.State { return llb.Image("alpine") },
			callbacks: []policysession.PolicyCallback{
				func(_ context.Context, req *policysession.CheckPolicyRequest) (*policysession.DecisionResponse, *pb.ResolveSourceMetaRequest, error) {
					require.Equal(t, runtime.GOOS, req.Platform.OS)
					require.Equal(t, runtime.GOARCH, req.Platform.Architecture)

					require.Equal(t, alpineID, req.Source.Source.Identifier)
					return allow()
				},
			},
		},
		{
			name:  "alpine with attrs",
			state: func() llb.State { return llb.Image("alpine", llb.WithLayerLimit(1)) },
			callbacks: []policysession.PolicyCallback{
				func(_ context.Context, req *policysession.CheckPolicyRequest) (*policysession.DecisionResponse, *pb.ResolveSourceMetaRequest, error) {
					require.Equal(t, alpineID, req.Source.Source.Identifier)
					wantAttrs := map[string]string{
						"image.layerlimit": "1",
					}
					require.Equal(t, wantAttrs, req.Source.Source.Attrs)
					return allow()
				},
			},
		},
		{
			name:  "deny alpine",
			state: func() llb.State { return llb.Image("alpine") },
			callbacks: []policysession.PolicyCallback{
				func(_ context.Context, req *policysession.CheckPolicyRequest) (*policysession.DecisionResponse, *pb.ResolveSourceMetaRequest, error) {
					require.Equal(t, alpineID, req.Source.Source.Identifier)
					return nil, nil, errors.New("policy denied")
				},
			},
			expectedError: "policy denied",
		},
		{
			name:  "alpine with digest policy",
			state: func() llb.State { return llb.Image("alpine") },
			callbacks: []policysession.PolicyCallback{
				// First request carries no image metadata; ask for it.
				func(_ context.Context, req *policysession.CheckPolicyRequest) (*policysession.DecisionResponse, *pb.ResolveSourceMetaRequest, error) {
					require.Equal(t, alpineID, req.Source.Source.Identifier)
					require.Nil(t, req.Source.Image)
					metaReq := &pb.ResolveSourceMetaRequest{
						Source:   req.Source.Source,
						Platform: req.Platform,
					}
					return nil, metaReq, nil
				},
				// Second request must carry a parseable digest and an
				// image config that decodes as an OCI image.
				func(_ context.Context, req *policysession.CheckPolicyRequest) (*policysession.DecisionResponse, *pb.ResolveSourceMetaRequest, error) {
					require.Equal(t, alpineID, req.Source.Source.Identifier)
					require.NotEmpty(t, req.Source.Image.Digest)
					_, parseErr := digest.Parse(req.Source.Image.Digest)
					require.NoError(t, parseErr)
					require.NotEmpty(t, req.Source.Image.Config)
					var cfg ocispecs.Image
					require.NoError(t, json.Unmarshal(req.Source.Image.Config, &cfg))
					require.NotEmpty(t, cfg.RootFS)
					return allow()
				},
			},
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			def, err := sc.state().Marshal(ctx)
			require.NoError(t, err)

			// calls counts how many requests the provider has dispatched.
			calls := 0

			provider := policysession.NewPolicyProvider(func(ctx context.Context, req *policysession.CheckPolicyRequest) (*policysession.DecisionResponse, *pb.ResolveSourceMetaRequest, error) {
				idx := calls
				if idx >= len(sc.callbacks) {
					return nil, nil, errors.Errorf("too many calls to policy callback %d", idx)
				}
				calls++
				return sc.callbacks[idx](ctx, req)
			})

			_, err = c.Solve(ctx, def, SolveOpt{
				SourcePolicyProvider: provider,
			}, nil)
			if sc.expectedError == "" {
				require.NoError(t, err)
				require.Equal(t, len(sc.callbacks), calls, "not all policy callbacks were called")
				return
			}
			require.Error(t, err)
			require.Contains(t, err.Error(), sc.expectedError)
		})
	}
}

// testSourceMetaPolicySession checks that a session policy provider is
// consulted when a gateway frontend calls ResolveSourceMetadata.
//
// Each case supplies the source op and resolver options to resolve, plus the
// callbacks expected during that build in call order. The provider forwards
// the n-th request to the n-th callback and returns an error once the list is
// exhausted. A case with expectedError set must fail the build with an error
// containing that text; every other case must succeed and consume all of its
// callbacks.
func testSourceMetaPolicySession(t *testing.T, sb integration.Sandbox) {
	requiresLinux(t)

	ctx := sb.Context()

	c, err := New(ctx, sb.Address())
	require.NoError(t, err)
	defer c.Close()

	type tcase struct {
		name          string
		source        func() (*opspb.SourceOp, sourceresolver.Opt)
		callbacks     []policysession.PolicyCallback
		expectedError string
	}
	tcases := []tcase{
		{
			name: "basic alpine",
			source: func() (*opspb.SourceOp, sourceresolver.Opt) {
				p := platforms.DefaultSpec()
				return &opspb.SourceOp{
						Identifier: "docker-image://docker.io/library/alpine:latest",
					}, sourceresolver.Opt{
						ImageOpt: &sourceresolver.ResolveImageOpt{
							Platform: &p,
						},
					}
			},
			callbacks: []policysession.PolicyCallback{
				func(ctx context.Context, req *policysession.CheckPolicyRequest) (*policysession.DecisionResponse, *pb.ResolveSourceMetaRequest, error) {
					require.Equal(t, runtime.GOOS, req.Platform.OS)
					require.Equal(t, runtime.GOARCH, req.Platform.Architecture)

					require.Equal(t, "docker-image://docker.io/library/alpine:latest", req.Source.Source.Identifier)
					return &policysession.DecisionResponse{
						Action: sourcepolicypb.PolicyAction_ALLOW,
					}, nil, nil
				},
			},
		},
		{
			name: "alpine denied",
			source: func() (*opspb.SourceOp, sourceresolver.Opt) {
				return &opspb.SourceOp{
					Identifier: "docker-image://docker.io/library/alpine:latest",
				}, sourceresolver.Opt{}
			},
			callbacks: []policysession.PolicyCallback{
				func(ctx context.Context, req *policysession.CheckPolicyRequest) (*policysession.DecisionResponse, *pb.ResolveSourceMetaRequest, error) {
					require.Equal(t, "docker-image://docker.io/library/alpine:latest", req.Source.Source.Identifier)
					return nil, nil, errors.New("policy denied")
				},
			},
			expectedError: "policy denied",
		},
	}

	for _, tc := range tcases {
		t.Run(tc.name, func(t *testing.T) {
			// callCounter is the index of the next expected callback.
			callCounter := 0

			p := policysession.NewPolicyProvider(func(ctx context.Context, req *policysession.CheckPolicyRequest) (*policysession.DecisionResponse, *pb.ResolveSourceMetaRequest, error) {
				if callCounter >= len(tc.callbacks) {
					return nil, nil, errors.Errorf("too many calls to policy callback %d", callCounter)
				}
				cb := tc.callbacks[callCounter]
				callCounter++
				return cb(ctx, req)
			})

			// Both the subtest and the build callback declare their own err
			// so that neither writes to the enclosing function's variable,
			// which is shared by every subtest.
			_, err := c.Build(ctx, SolveOpt{
				SourcePolicyProvider: p,
			}, "test", func(ctx context.Context, gc gateway.Client) (*gateway.Result, error) {
				sop, opts := tc.source()
				_, err := gc.ResolveSourceMetadata(ctx, sop, opts)
				return nil, err
			}, nil)

			if tc.expectedError != "" {
				require.Error(t, err)
				require.Contains(t, err.Error(), tc.expectedError)
				return
			}
			require.NoError(t, err)

			require.Equal(t, len(tc.callbacks), callCounter, "not all policy callbacks were called")
		})
	}
}

// testSourcePolicyParallelSession solves a definition with two image sources
// (alpine and busybox) through a session policy provider whose callbacks
// wait on each other.
//
// The handshake uses two channels:
//   - the first alpine request blocks until the busybox request has started
//     (waitBusyboxStart), then asks for source metadata;
//   - the busybox request blocks until the second alpine request has been
//     answered (waitAlpineDone).
//
// The solve can therefore only finish if the busybox request stays in flight
// while both alpine requests are handled. The test expects exactly two alpine
// requests and one busybox request.
//
// NOTE(review): the callback uses require.*, which presumably aborts via
// t.FailNow on failure; whether the provider invokes the callback on the
// goroutine running the test is not visible here — confirm.
func testSourcePolicyParallelSession(t *testing.T, sb integration.Sandbox) {
	requiresLinux(t)

	ctx := sb.Context()

	c, err := New(ctx, sb.Address())
	require.NoError(t, err)
	defer c.Close()

	// alpine is the base state; busybox is only the source of the copied file.
	def, err := llb.Image("alpine").File(llb.Copy(llb.Image("busybox"), "/etc/passwd", "passwd2")).Marshal(ctx)
	require.NoError(t, err)

	// Number of policy requests seen per source.
	countAlpine := 0
	countBusybox := 0
	// Closed by the busybox request once it has started.
	waitBusyboxStart := make(chan struct{})
	// Closed by the second alpine request just before it returns ALLOW.
	waitAlpineDone := make(chan struct{})

	p := policysession.NewPolicyProvider(func(ctx context.Context, req *policysession.CheckPolicyRequest) (*policysession.DecisionResponse, *pb.ResolveSourceMetaRequest, error) {
		switch req.Source.Source.Identifier {
		case "docker-image://docker.io/library/alpine:latest":
			switch countAlpine {
			case 0:
				// First alpine request: hold until the busybox request is
				// running, then ask for metadata instead of deciding.
				<-waitBusyboxStart
				require.Nil(t, req.Source.Image)
				countAlpine++
				return nil, &pb.ResolveSourceMetaRequest{
					Source:   req.Source.Source,
					Platform: req.Platform,
				}, nil
			case 1:
				// Second alpine request: image metadata must now be present.
				require.NotNil(t, req.Source.Image)
				require.True(t, strings.HasPrefix(req.Source.Image.Digest, "sha256:"))
				countAlpine++
				// Release the busybox request that is waiting below.
				close(waitAlpineDone)
				return &policysession.DecisionResponse{
					Action: sourcepolicypb.PolicyAction_ALLOW,
				}, nil, nil
			default:
				require.Fail(t, "too many calls for alpine")
			}
		case "docker-image://docker.io/library/busybox:latest":
			// Delay briefly before signalling the start, then stay blocked
			// until the alpine exchange has finished. A second busybox
			// request would close waitBusyboxStart again and panic.
			time.Sleep(200 * time.Millisecond)
			close(waitBusyboxStart)
			countBusybox++
			<-waitAlpineDone
			return &policysession.DecisionResponse{
				Action: sourcepolicypb.PolicyAction_ALLOW,
			}, nil, nil
		}
		// Reached for any other identifier.
		return nil, nil, errors.Errorf("unexpected source %q", req.Source.Source.Identifier)
	})

	_, err = c.Solve(ctx, def, SolveOpt{
		SourcePolicyProvider: p,
	}, nil)
	require.NoError(t, err)

	require.Equal(t, 2, countAlpine)
	require.Equal(t, 1, countBusybox)
}

// testSourcePolicySignedCommit exercises source policies against a throwaway
// git repository that is served over HTTP by an httptest file server.
//
// The repository is built in three steps:
//   - an initial commit with annotated tag v0.1, made without any fixture git
//     config;
//   - a second commit followed by `git checkout -B v2`, run with
//     GIT_CONFIG_GLOBAL pointing at the user1 gpg fixture config (presumably
//     that config enables commit signing; its contents are not visible here);
//   - a tag v2.0 created with `git tag -s` under the user2 ssh fixture config.
//
// The first group of subtests passes a static SourcePolicy with a single
// CONVERT rule that adds git.sig.* attributes to the matching git source, and
// checks whether the build succeeds or fails with the expected message. The
// second group uses a session policy provider and checks the git metadata
// returned to it after it asks for source metadata.
//
// The test is skipped when BUILDKIT_TEST_SIGN_FIXTURES is not set.
func testSourcePolicySignedCommit(t *testing.T, sb integration.Sandbox) {
	requiresLinux(t)
	ctx := sb.Context()
	c, err := New(ctx, sb.Address())
	require.NoError(t, err)
	defer c.Close()

	fixtures, found := os.LookupEnv("BUILDKIT_TEST_SIGN_FIXTURES")
	if !found {
		t.Skip("missing BUILDKIT_TEST_SIGN_FIXTURES")
	}

	// signEnv returns the environment selecting the fixture git config for
	// the given user and signing method.
	signEnv := func(user, method string) []string {
		cfg := filepath.Join(fixtures, user+"."+method+".gitconfig")
		return []string{"GIT_CONFIG_GLOBAL=" + cfg}
	}

	repo := t.TempDir()
	require.NoError(t, runInDir(repo,
		"git init",
		"git config --local user.email test",
		"git config --local user.name test",
		"echo a > a",
		"git add a",
		"git commit -m a",
		"git tag -a v0.1 -m v0.1",
	))
	require.NoError(t, runInDirEnv(repo, signEnv("user1", "gpg"),
		"echo b > b",
		"git add b",
		"git commit -m b",
		"git checkout -B v2",
	))
	require.NoError(t, runInDirEnv(repo, signEnv("user2", "ssh"),
		"git tag -s -a v2.0 -m v2.0-tag",
		"git update-server-info",
	))

	server := httptest.NewServer(http.FileServer(http.Dir(filepath.Clean(repo))))
	defer server.Close()

	user1GPGKey, err := os.ReadFile(filepath.Join(fixtures, "user1.gpg.pub"))
	require.NoError(t, err)

	user2SSHKey, err := os.ReadFile(filepath.Join(fixtures, "user2.ssh.pub"))
	require.NoError(t, err)

	// remote is the URL handed to llb.Git; gitURL is the prefix of the source
	// identifier that the policies select on and the callbacks expect.
	remote := server.URL + "/.git"
	gitURL := "git://" + strings.TrimPrefix(server.URL, "http://") + "/.git"

	// gitAt returns the git source state for the given ref.
	gitAt := func(ref string) llb.State {
		return llb.Git(remote, "", llb.GitRef(ref))
	}

	// convertPolicy builds a policy with one CONVERT rule that keeps the
	// identifier for ref unchanged and sets attrs on it.
	convertPolicy := func(ref string, attrs map[string]string) *sourcepolicypb.Policy {
		id := gitURL + "#" + ref
		rule := &sourcepolicypb.Rule{
			Action:   sourcepolicypb.PolicyAction_CONVERT,
			Selector: &sourcepolicypb.Selector{Identifier: id},
			Updates:  &sourcepolicypb.Update{Identifier: id, Attrs: attrs},
		}
		return &sourcepolicypb.Policy{Rules: []*sourcepolicypb.Rule{rule}}
	}

	type staticCase struct {
		name        string
		ref         string
		attrs       map[string]string
		expectedErr string
	}

	staticCases := []staticCase{
		{
			name: "unsigned commit fails",
			ref:  "v0.1",
			attrs: map[string]string{
				"git.sig.pubkey": string(user1GPGKey),
			},
			expectedErr: "git object is not signed",
		},
		{
			name: "valid gpg signature for branch",
			ref:  "v2",
			attrs: map[string]string{
				"git.sig.pubkey":          string(user1GPGKey),
				"git.sig.rejectexpired":   "true",
				"git.sig.ignoresignedtag": "false",
			},
		},
		{
			name: "valid ssh signature for signed tag",
			ref:  "v2.0",
			attrs: map[string]string{
				"git.sig.pubkey":           string(user2SSHKey),
				"git.sig.requiresignedtag": "true",
				"git.sig.rejectexpired":    "true",
			},
		},
		{
			name: "invalid ssh signature for signed tag",
			ref:  "v2.0",
			attrs: map[string]string{
				"git.sig.pubkey":           string(user1GPGKey),
				"git.sig.requiresignedtag": "true",
				"git.sig.rejectexpired":    "true",
			},
			expectedErr: "failed to parse ssh public key",
		},
		{
			name: "commit ssh signature for signed tag",
			ref:  "v2.0",
			attrs: map[string]string{
				"git.sig.pubkey":           string(user1GPGKey),
				"git.sig.requiresignedtag": "false",
				"git.sig.rejectexpired":    "true",
			},
		},
		{
			name: "invalid tag signature for commit",
			ref:  "v2.0",
			attrs: map[string]string{
				"git.sig.pubkey":          string(user2SSHKey),
				"git.sig.rejectexpired":   "true",
				"git.sig.ignoresignedtag": "true",
			},
			expectedErr: "failed to read armored public key: openpgp",
		},
	}

	for _, tt := range staticCases {
		t.Run(tt.name, func(t *testing.T) {
			// The frontend copies file "a" out of the git source so that the
			// source actually has to be resolved.
			frontend := func(ctx context.Context, c gateway.Client) (*gateway.Result, error) {
				copied := llb.Scratch().File(llb.Copy(gitAt(tt.ref), "a", "/a2"))
				def, err := copied.Marshal(sb.Context())
				if err != nil {
					return nil, err
				}
				req := gateway.SolveRequest{Definition: def.ToPB()}
				return c.Solve(ctx, req)
			}

			_, err := c.Build(sb.Context(), SolveOpt{
				SourcePolicy: convertPolicy(tt.ref, tt.attrs),
			}, "", frontend, nil)
			if tt.expectedErr != "" {
				require.ErrorContains(t, err, tt.expectedErr, "test case %q failed", tt.name)
				return
			}
			require.NoError(t, err, "test case %q failed", tt.name)
		})
	}

	// session policy based test cases

	// tagID is the identifier every session callback expects.
	tagID := gitURL + "#v2.0"

	// requestMeta builds the callback for the first request of a case: git
	// metadata must be absent, and the callback asks for it. When
	// returnObject is set the request additionally asks for the git object.
	requestMeta := func(returnObject bool) policysession.PolicyCallback {
		return func(_ context.Context, req *policysession.CheckPolicyRequest) (*policysession.DecisionResponse, *pb.ResolveSourceMetaRequest, error) {
			require.Equal(t, tagID, req.Source.Source.Identifier)
			require.Nil(t, req.Source.Git)
			metaReq := &pb.ResolveSourceMetaRequest{
				Source:   req.Source.Source,
				Platform: req.Platform,
			}
			if returnObject {
				metaReq.Git = &pb.ResolveSourceGitRequest{
					ReturnObject: true,
				}
			}
			return nil, metaReq, nil
		}
	}

	// checkChecksums holds the assertions shared by the second request of
	// both cases: two distinct 40-character checksums must be present.
	checkChecksums := func(req *policysession.CheckPolicyRequest) {
		require.Equal(t, tagID, req.Source.Source.Identifier)
		require.NotNil(t, req.Source.Git)
		require.Len(t, req.Source.Git.Checksum, 40)
		require.Len(t, req.Source.Git.CommitChecksum, 40)
		require.NotEqual(t, req.Source.Git.Checksum, req.Source.Git.CommitChecksum)
	}

	type sessionCase struct {
		name          string
		state         func() llb.State
		callbacks     []policysession.PolicyCallback
		expectedError string
	}

	sessionCases := []sessionCase{
		{
			name:  "gitchecksum",
			state: func() llb.State { return gitAt("v2.0") },
			callbacks: []policysession.PolicyCallback{
				requestMeta(false),
				func(_ context.Context, req *policysession.CheckPolicyRequest) (*policysession.DecisionResponse, *pb.ResolveSourceMetaRequest, error) {
					checkChecksums(req)
					require.Nil(t, req.Source.Git.CommitObject)
					return &policysession.DecisionResponse{
						Action: sourcepolicypb.PolicyAction_ALLOW,
					}, nil, nil
				},
			},
		},
		{
			name:  "gitobjects",
			state: func() llb.State { return gitAt("v2.0") },
			callbacks: []policysession.PolicyCallback{
				requestMeta(true),
				func(_ context.Context, req *policysession.CheckPolicyRequest) (*policysession.DecisionResponse, *pb.ResolveSourceMetaRequest, error) {
					checkChecksums(req)
					require.NotNil(t, req.Source.Git.CommitObject)
					require.Greater(t, len(req.Source.Git.CommitObject), 50)
					return &policysession.DecisionResponse{
						Action: sourcepolicypb.PolicyAction_ALLOW,
					}, nil, nil
				},
			},
		},
	}

	for _, sc := range sessionCases {
		t.Run(sc.name, func(t *testing.T) {
			def, err := sc.state().Marshal(ctx)
			require.NoError(t, err)

			// calls counts how many requests the provider has dispatched.
			calls := 0

			provider := policysession.NewPolicyProvider(func(ctx context.Context, req *policysession.CheckPolicyRequest) (*policysession.DecisionResponse, *pb.ResolveSourceMetaRequest, error) {
				idx := calls
				if idx >= len(sc.callbacks) {
					return nil, nil, errors.Errorf("too many calls to policy callback %d", idx)
				}
				calls++
				return sc.callbacks[idx](ctx, req)
			})

			_, err = c.Solve(ctx, def, SolveOpt{
				SourcePolicyProvider: provider,
			}, nil)
			if sc.expectedError == "" {
				require.NoError(t, err)
				require.Equal(t, len(sc.callbacks), calls, "not all policy callbacks were called")
				return
			}
			require.Error(t, err)
			require.Contains(t, err.Error(), sc.expectedError)
		})
	}
}
