diff --git a/api.go b/api.go index b6db1beb..9204411d 100644 --- a/api.go +++ b/api.go @@ -638,7 +638,7 @@ func (r *Repository) UnmarshalJSON(data []byte) error { r.ID = uint32(id) } - if v, ok := repo.RawConfig["tenantid"]; ok { + if v, ok := repo.RawConfig["tenantID"]; ok { id, _ := strconv.ParseInt(v, 10, 64) r.TenantID = int(id) } diff --git a/build/builder.go b/build/builder.go index b922e92b..2e7c6ef3 100644 --- a/build/builder.go +++ b/build/builder.go @@ -118,8 +118,8 @@ type Options struct { // Note: heap checking is "best effort", and it's possible for the process to OOM without triggering the heap profile. HeapProfileTriggerBytes uint64 - // UseSourcegraphIDForName is true if we want to use the Sourcegraph ID as prefix for the shard name. - UseSourcegraphIDForName bool + // ShardPrefix is the prefix of the shard. If empty, the repository name is used. + ShardPrefix string } // HashOptions contains only the options in Options that upon modification leads to IndexState of IndexStateMismatch during the next index building. @@ -184,6 +184,7 @@ func (o *Options) Flags(fs *flag.FlagSet) { fs.StringVar(&o.IndexDir, "index", x.IndexDir, "directory for search indices") fs.BoolVar(&o.CTagsMustSucceed, "require_ctags", x.CTagsMustSucceed, "If set, ctags calls must succeed.") fs.Var(largeFilesFlag{o}, "large_file", "A glob pattern where matching files are to be index regardless of their size. You can add multiple patterns by setting this more than once.") + fs.StringVar(&o.ShardPrefix, "shard_prefix", x.ShardPrefix, "the prefix of the shard. If empty, the repository name is used.") // Sourcegraph specific fs.BoolVar(&o.DisableCTags, "disable_ctags", x.DisableCTags, "If set, ctags will not be called.") @@ -231,6 +232,10 @@ func (o *Options) Args() []string { args = append(args, "-shard_merging") } + if o.ShardPrefix != "" { + args = append(args, "-shard_prefix", o.ShardPrefix) + } + return args } @@ -341,8 +346,8 @@ func (o *Options) shardName(n int) string { func (o *Options) shardNameVersion(version, n int) string { var prefix string - if o.UseSourcegraphIDForName { - prefix = fmt.Sprintf("%d", o.RepositoryDescription.ID) + if o.ShardPrefix != "" { + prefix = o.ShardPrefix } else { prefix = url.QueryEscape(o.RepositoryDescription.Name) } diff --git a/build/e2e_test.go b/build/e2e_test.go index 1eb5acff..b211cb3e 100644 --- a/build/e2e_test.go +++ b/build/e2e_test.go @@ -25,6 +25,7 @@ import ( "reflect" "runtime" "sort" + "strconv" "strings" "testing" "time" @@ -32,7 +33,11 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/grafana/regexp" + "github.com/stretchr/testify/require" + "github.com/sourcegraph/zoekt" + "github.com/sourcegraph/zoekt/internal/tenant" + "github.com/sourcegraph/zoekt/internal/tenant/tenanttest" "github.com/sourcegraph/zoekt/query" "github.com/sourcegraph/zoekt/shards" ) @@ -153,6 +158,88 @@ func TestBasic(t *testing.T) { }) } +func TestBasicTenant(t *testing.T) { + tenanttest.MockEnforce(t) + + dir := t.TempDir() + + ctx1 := tenanttest.NewTestContext() + tnt1, err := tenant.FromContext(ctx1) + require.NoError(t, err) + + opts := Options{ + IndexDir: dir, + ShardMax: 1024, + RepositoryDescription: zoekt.Repository{ + Name: "repo", + RawConfig: map[string]string{"tenantID": strconv.Itoa(tnt1.ID())}, + }, + Parallelism: 2, + SizeMax: 1 << 20, + } + + b, err := NewBuilder(opts) + if err != nil { + t.Fatalf("NewBuilder: %v", err) + } + + for i := 0; i < 4; i++ { + s := fmt.Sprintf("%d", i) + if err := b.AddFile("F"+s, []byte(strings.Repeat(s, 1000))); err != nil { + t.Fatal(err) + } + } + + if err := b.Finish(); err != nil { + t.Errorf("Finish: %v", err) + } + + fs, _ := filepath.Glob(dir + "/*.zoekt") + if len(fs) <= 1 { + t.Fatalf("want multiple shards, got %v", fs) + } + + _, md0, err := zoekt.ReadMetadataPath(fs[0]) + if err != nil { + t.Fatal(err) + } + for _, f := range fs[1:] { + _, md, err := zoekt.ReadMetadataPath(f) + if err != nil { + t.Fatal(err) + } + if md.IndexTime != md0.IndexTime { + t.Fatalf("wanted identical time stamps but got %v!=%v", md.IndexTime, md0.IndexTime) + } + if md.ID != md0.ID { + t.Fatalf("wanted identical IDs but got %s!=%s", md.ID, md0.ID) + } + } + + ss, err := shards.NewDirectorySearcher(dir) + if err != nil { + t.Fatalf("NewDirectorySearcher(%s): %v", dir, err) + } + defer ss.Close() + + q, err := query.Parse("111") + if err != nil { + t.Fatalf("Parse(111): %v", err) + } + + var sOpts zoekt.SearchOptions + // Tenant 1 has access to the repo + result, err := ss.Search(ctx1, q, &sOpts) + require.NoError(t, err) + require.Len(t, result.Files, 1) + + // Tenant 2 does not have access to the repo + ctx2 := tenanttest.NewTestContext() + result, err = ss.Search(ctx2, q, &sOpts) + require.NoError(t, err) + require.Len(t, result.Files, 0) +} + // retryTest will retry f until min(t.Deadline(), time.Minute). It returns // once f doesn't call fatalf. func retryTest(t *testing.T, f func(fatalf func(format string, args ...interface{}))) { diff --git a/cmd/zoekt-git-index/main.go b/cmd/zoekt-git-index/main.go index 7bb0d356..2f1b93b3 100644 --- a/cmd/zoekt-git-index/main.go +++ b/cmd/zoekt-git-index/main.go @@ -48,8 +48,6 @@ func run() int { cpuProfile := flag.String("cpuprofile", "", "write cpu profile to `file`") - useSourcegraphIDForName := flag.Bool("use_sourcegraph_id_for_name", false, "use the Sourcegraph ID for the shard name") - flag.Parse() // Tune GOMAXPROCS to match Linux container CPU quota. @@ -77,7 +75,6 @@ func run() int { opts := cmd.OptionsFromFlags() opts.IsDelta = *isDelta - opts.UseSourcegraphIDForName = *useSourcegraphIDForName var branches []string if *branchesStr != "" { diff --git a/cmd/zoekt-sourcegraph-indexserver/index.go b/cmd/zoekt-sourcegraph-indexserver/index.go index f2ef6dad..650e8c01 100644 --- a/cmd/zoekt-sourcegraph-indexserver/index.go +++ b/cmd/zoekt-sourcegraph-indexserver/index.go @@ -17,12 +17,11 @@ import ( "strings" "time" + sglog "github.com/sourcegraph/log" + "github.com/sourcegraph/zoekt" "github.com/sourcegraph/zoekt/build" "github.com/sourcegraph/zoekt/ctags" - "github.com/sourcegraph/zoekt/internal/tenant" - - sglog "github.com/sourcegraph/log" ) const defaultIndexingTimeout = 1*time.Hour + 30*time.Minute @@ -100,18 +99,16 @@ type indexArgs struct { // ShardMerging is true if we want zoekt-git-index to respect compound shards. ShardMerging bool - // UseSourcegraphIDForName is true if we want to use the Sourcegraph ID as prefix for the shard name. - UseSourcegraphIDForName bool + // IdBasedNames is true if we want to use ID-based names as prefix for the shard name. + IdBasedNames bool } // BuildOptions returns a build.Options represented by indexArgs. Note: it // doesn't set fields like repository/branch. func (o *indexArgs) BuildOptions() *build.Options { - - // Default to tenant 1 if no tenant is set. - tenantID := o.TenantID - if o.TenantID < 1 { - tenantID = 1 + shardPrefix := "" + if o.IdBasedNames { + shardPrefix = fmt.Sprintf("%09d_%09d", o.TenantID, o.IndexOptions.RepoID) } return &build.Options{ @@ -132,7 +129,7 @@ func (o *indexArgs) BuildOptions() *build.Options { "archived": marshalBool(o.Archived), // Calculate repo rank based on the latest commit date. "latestCommitDate": "1", - "tenantid": strconv.Itoa(tenantID), + "tenantID": strconv.Itoa(o.TenantID), }, }, IndexDir: o.IndexDir, @@ -147,7 +144,7 @@ func (o *indexArgs) BuildOptions() *build.Options { ShardMerging: o.ShardMerging, - UseSourcegraphIDForName: o.UseSourcegraphIDForName, + ShardPrefix: shardPrefix, } } @@ -263,7 +260,7 @@ func fetchRepo(ctx context.Context, gitDir string, o *indexArgs, c gitIndexConfi "-C", gitDir, "-c", "protocol.version=2", "-c", "http.extraHeader=X-Sourcegraph-Actor-UID: internal", - "-c", "http.extraHeader=" + tenant.HttpExtraHeader(o.TenantID), + "-c", "http.extraHeader=X-Sourcegraph-Tenant-ID: " + strconv.Itoa(o.TenantID), "fetch", "--depth=1", "--no-tags", } @@ -410,10 +407,6 @@ func indexRepo(ctx context.Context, gitDir string, sourcegraph Sourcegraph, o *i args = append(args, "-delta_threshold", strconv.FormatUint(o.DeltaShardNumberFallbackThreshold, 10)) } - if o.UseSourcegraphIDForName { - args = append(args, "-use_sourcegraph_id_for_name") - } - if len(o.LanguageMap) > 0 { var languageMap []string for language, parser := range o.LanguageMap { diff --git a/cmd/zoekt-sourcegraph-indexserver/index_test.go b/cmd/zoekt-sourcegraph-indexserver/index_test.go index 22faa5b9..fabf95fe 100644 --- a/cmd/zoekt-sourcegraph-indexserver/index_test.go +++ b/cmd/zoekt-sourcegraph-indexserver/index_test.go @@ -487,11 +487,12 @@ func TestIndex(t *testing.T) { Name: "test/repo", CloneURL: "http://api.test/.internal/git/test/repo", Branches: []zoekt.RepositoryBranch{{Name: "HEAD", Version: "deadbeef"}}, + TenantID: 42, }, }, want: []string{ "git -c init.defaultBranch=nonExistentBranchBB0FOFCH32 init --bare $TMPDIR/test%2Frepo.git", - "git -C $TMPDIR/test%2Frepo.git -c protocol.version=2 -c http.extraHeader=X-Sourcegraph-Actor-UID: internal -c http.extraHeader=X-Sourcegraph-Tenant-ID: 1 fetch --depth=1 --no-tags --filter=blob:limit=1m http://api.test/.internal/git/test/repo deadbeef", + "git -C $TMPDIR/test%2Frepo.git -c protocol.version=2 -c http.extraHeader=X-Sourcegraph-Actor-UID: internal -c http.extraHeader=X-Sourcegraph-Tenant-ID: 42 fetch --depth=1 --no-tags --filter=blob:limit=1m http://api.test/.internal/git/test/repo deadbeef", "git -C $TMPDIR/test%2Frepo.git update-ref HEAD deadbeef", "git -C $TMPDIR/test%2Frepo.git config zoekt.archived 0", "git -C $TMPDIR/test%2Frepo.git config zoekt.fork 0", @@ -500,7 +501,7 @@ func TestIndex(t *testing.T) { "git -C $TMPDIR/test%2Frepo.git config zoekt.priority 0", "git -C $TMPDIR/test%2Frepo.git config zoekt.public 0", "git -C $TMPDIR/test%2Frepo.git config zoekt.repoid 0", - "git -C $TMPDIR/test%2Frepo.git config zoekt.tenantid 1", + "git -C $TMPDIR/test%2Frepo.git config zoekt.tenantID 42", "zoekt-git-index -submodules=false -branches HEAD -disable_ctags $TMPDIR/test%2Frepo.git", }, }, { @@ -511,6 +512,7 @@ func TestIndex(t *testing.T) { CloneURL: "http://api.test/.internal/git/test/repo", Branches: []zoekt.RepositoryBranch{{Name: "HEAD", Version: "deadbeef"}}, RepoID: 123, + TenantID: 1, }, }, want: []string{ @@ -524,7 +526,7 @@ func TestIndex(t *testing.T) { "git -C $TMPDIR/test%2Frepo.git config zoekt.priority 0", "git -C $TMPDIR/test%2Frepo.git config zoekt.public 0", "git -C $TMPDIR/test%2Frepo.git config zoekt.repoid 123", - "git -C $TMPDIR/test%2Frepo.git config zoekt.tenantid 1", + "git -C $TMPDIR/test%2Frepo.git config zoekt.tenantID 1", "zoekt-git-index -submodules=false -branches HEAD -disable_ctags $TMPDIR/test%2Frepo.git", }, }, { @@ -543,6 +545,7 @@ func TestIndex(t *testing.T) { {Name: "HEAD", Version: "deadbeef"}, {Name: "dev", Version: "feebdaed"}, // ignored for archive }, + TenantID: 1, }, }, want: []string{ @@ -557,7 +560,7 @@ func TestIndex(t *testing.T) { "git -C $TMPDIR/test%2Frepo.git config zoekt.priority 0", "git -C $TMPDIR/test%2Frepo.git config zoekt.public 0", "git -C $TMPDIR/test%2Frepo.git config zoekt.repoid 0", - "git -C $TMPDIR/test%2Frepo.git config zoekt.tenantid 1", + "git -C $TMPDIR/test%2Frepo.git config zoekt.tenantID 1", "zoekt-git-index -submodules=false -incremental -branches HEAD,dev " + "-file_limit 123 -parallelism 4 -index /data/index -require_ctags -large_file foo -large_file bar " + "$TMPDIR/test%2Frepo.git", @@ -581,6 +584,7 @@ func TestIndex(t *testing.T) { {Name: "dev", Version: "feebdaed"}, {Name: "release", Version: "12345678"}, }, + TenantID: 1, }, DeltaShardNumberFallbackThreshold: 22, }, @@ -606,7 +610,7 @@ func TestIndex(t *testing.T) { "git -C $TMPDIR/test%2Frepo.git config zoekt.priority 0", "git -C $TMPDIR/test%2Frepo.git config zoekt.public 0", "git -C $TMPDIR/test%2Frepo.git config zoekt.repoid 0", - "git -C $TMPDIR/test%2Frepo.git config zoekt.tenantid 1", + "git -C $TMPDIR/test%2Frepo.git config zoekt.tenantID 1", "zoekt-git-index -submodules=false -incremental -branches HEAD,dev,release " + "-delta -delta_threshold 22 -file_limit 123 -parallelism 4 -index /data/index -require_ctags -large_file foo -large_file bar " + "$TMPDIR/test%2Frepo.git", diff --git a/cmd/zoekt-sourcegraph-indexserver/main.go b/cmd/zoekt-sourcegraph-indexserver/main.go index 839302aa..d2d8f5d8 100644 --- a/cmd/zoekt-sourcegraph-indexserver/main.go +++ b/cmd/zoekt-sourcegraph-indexserver/main.go @@ -53,6 +53,7 @@ import ( "github.com/sourcegraph/zoekt/grpc/internalerrs" "github.com/sourcegraph/zoekt/grpc/messagesize" "github.com/sourcegraph/zoekt/internal/profiler" + "github.com/sourcegraph/zoekt/internal/tenant" ) var ( @@ -213,7 +214,7 @@ type Server struct { // timeout defines how long the index server waits before killing an indexing job. timeout time.Duration - useSourcegraphIDForName bool + idBasedNames bool } var debug = log.New(io.Discard, "", log.LstdFlags) @@ -559,6 +560,11 @@ func (s *Server) Index(args *indexArgs) (state indexState, err error) { tr.Finish() }() + // Sourcegraph should always provide a tenant ID. + if args.TenantID < 1 { + return indexStateFail, tenant.ErrMissingTenant + } + tr.LazyPrintf("branches: %v", args.Branches) if len(args.Branches) == 0 { @@ -673,13 +679,13 @@ func sglogBranches(key string, branches []zoekt.RepositoryBranch) sglog.Field { func (s *Server) indexArgs(opts IndexOptions) *indexArgs { parallelism := s.parallelism(opts, runtime.GOMAXPROCS(0)) return &indexArgs{ - IndexOptions: opts, - IndexDir: s.IndexDir, - Parallelism: parallelism, - Incremental: true, - FileLimit: MaxFileSize, - ShardMerging: s.shardMerging, - UseSourcegraphIDForName: s.useSourcegraphIDForName, + IndexOptions: opts, + IndexDir: s.IndexDir, + Parallelism: parallelism, + Incremental: true, + FileLimit: MaxFileSize, + ShardMerging: s.shardMerging, + IdBasedNames: s.idBasedNames, } } @@ -1253,7 +1259,7 @@ type rootConfig struct { backoffDuration time.Duration maxBackoffDuration time.Duration - useSourcegraphIDForName bool + idBasedNames bool } func (rc *rootConfig) registerRootFlags(fs *flag.FlagSet) { @@ -1266,7 +1272,7 @@ func (rc *rootConfig) registerRootFlags(fs *flag.FlagSet) { fs.Float64Var(&rc.cpuFraction, "cpu_fraction", 1.0, "use this fraction of the cores for indexing.") fs.DurationVar(&rc.backoffDuration, "backoff_duration", getEnvWithDefaultDuration("BACKOFF_DURATION", 10*time.Minute), "for the given duration we backoff from enqueue operations for a repository that's failed its previous indexing attempt. Consecutive failures increase the duration of the delay linearly up to the maxBackoffDuration. A negative value disables indexing backoff.") fs.DurationVar(&rc.maxBackoffDuration, "max_backoff_duration", getEnvWithDefaultDuration("MAX_BACKOFF_DURATION", 120*time.Minute), "the maximum duration to backoff from enqueueing a repo for indexing. A negative value disables indexing backoff.") - fs.BoolVar(&rc.useSourcegraphIDForName, "use_sourcegraph_id_for_name", getEnvWithDefaultBool("SRC_USE_SOURCEGRAPH_ID_FOR_NAME", false), "use the Sourcegraph ID as the name for index shards.") + fs.BoolVar(&rc.idBasedNames, "id_based_names", getEnvWithDefaultBool("ID_BASED_NAMES", false), "use id-based prefixes for shards.") // flags related to shard merging fs.BoolVar(&rc.disableShardMerging, "shard_merging", getEnvWithDefaultBool("SRC_DISABLE_SHARD_MERGING", false), "disable shard merging") @@ -1473,7 +1479,7 @@ func newServer(conf rootConfig) (*Server, error) { CPUCount: cpuCount, queue: *q, shardMerging: !conf.disableShardMerging, - useSourcegraphIDForName: conf.useSourcegraphIDForName, + idBasedNames: conf.idBasedNames, deltaBuildRepositoriesAllowList: deltaBuildRepositoriesAllowList, deltaShardNumberFallbackThreshold: deltaShardNumberFallbackThreshold, repositoriesSkipSymbolsCalculationAllowList: reposShouldSkipSymbolsCalculation, diff --git a/cmd/zoekt-sourcegraph-indexserver/main_test.go b/cmd/zoekt-sourcegraph-indexserver/main_test.go index 806a6928..aebfa6fb 100644 --- a/cmd/zoekt-sourcegraph-indexserver/main_test.go +++ b/cmd/zoekt-sourcegraph-indexserver/main_test.go @@ -15,6 +15,7 @@ import ( sglog "github.com/sourcegraph/log" "github.com/sourcegraph/log/logtest" + "github.com/stretchr/testify/require" "github.com/xeipuuv/gojsonschema" "google.golang.org/grpc" @@ -22,6 +23,7 @@ import ( "github.com/sourcegraph/zoekt" proto "github.com/sourcegraph/zoekt/cmd/zoekt-sourcegraph-indexserver/protos/sourcegraph/zoekt/configuration/v1" + "github.com/sourcegraph/zoekt/internal/tenant" ) func TestServer_defaultArgs(t *testing.T) { @@ -51,6 +53,12 @@ func TestServer_defaultArgs(t *testing.T) { } } +func TestIndexNoTenant(t *testing.T) { + s := &Server{} + _, err := s.Index(&indexArgs{}) + require.ErrorIs(t, err, tenant.ErrMissingTenant) +} + func TestServer_parallelism(t *testing.T) { root, err := url.Parse("http://api.test") if err != nil { diff --git a/internal/tenant/context.go b/internal/tenant/context.go new file mode 100644 index 00000000..fa56f3a8 --- /dev/null +++ b/internal/tenant/context.go @@ -0,0 +1,38 @@ +package tenant + +import ( + "context" + "fmt" + "runtime/pprof" + + "go.uber.org/atomic" + + "github.com/sourcegraph/zoekt/internal/tenant/internal/enforcement" + "github.com/sourcegraph/zoekt/internal/tenant/internal/tenanttype" +) + +var ErrMissingTenant = fmt.Errorf("missing tenant") + +func FromContext(ctx context.Context) (*tenanttype.Tenant, error) { + tnt, ok := tenanttype.GetTenant(ctx) + if !ok { + if pprofMissingTenant != nil { + // We want to track every stack trace, so need a unique value for the event + eventValue := pprofUniqID.Add(1) + + // skip stack for Add and this function (2). + pprofMissingTenant.Add(eventValue, 2) + } + + return nil, ErrMissingTenant + } + return tnt, nil +} + +var pprofUniqID atomic.Int64 +var pprofMissingTenant = func() *pprof.Profile { + if !enforcement.ShouldLogNoTenant() { + return nil + } + return pprof.NewProfile("missing_tenant") +}() diff --git a/internal/tenant/grpc.go b/internal/tenant/grpc.go index 56571aed..774c216d 100644 --- a/internal/tenant/grpc.go +++ b/internal/tenant/grpc.go @@ -34,8 +34,8 @@ var _ propagator.Propagator = &Propagator{} func (Propagator) FromContext(ctx context.Context) metadata.MD { md := make(metadata.MD) - tenant, err := tenanttype.FromContext(ctx) - if err != nil { + tenant, ok := tenanttype.GetTenant(ctx) + if !ok { md.Append(headerKeyTenantID, headerValueNoTenant) } else { md.Append(headerKeyTenantID, strconv.Itoa(tenant.ID())) @@ -65,7 +65,7 @@ func (Propagator) InjectContext(ctx context.Context, md metadata.MD) (context.Co // UnaryServerInterceptor is a grpc.UnaryServerInterceptor that injects the tenant ID // from the context into pprof labels. func UnaryServerInterceptor(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (response any, err error) { - if tnt, err := tenanttype.FromContext(ctx); err == nil { + if tnt, ok := tenanttype.GetTenant(ctx); ok { defer pprof.SetGoroutineLabels(ctx) ctx = pprof.WithLabels(ctx, pprof.Labels("tenant", tenanttype.Marshal(tnt))) pprof.SetGoroutineLabels(ctx) @@ -77,7 +77,7 @@ func UnaryServerInterceptor(ctx context.Context, req any, _ *grpc.UnaryServerInf // StreamServerInterceptor is a grpc.StreamServerInterceptor that injects the tenant ID // from the context into pprof labels. func StreamServerInterceptor(srv any, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error { - if tnt, err := tenanttype.FromContext(ss.Context()); err == nil { + if tnt, ok := tenanttype.GetTenant(ss.Context()); ok { ctx := ss.Context() defer pprof.SetGoroutineLabels(ctx) ctx = pprof.WithLabels(ctx, pprof.Labels("tenant", tenanttype.Marshal(tnt))) diff --git a/internal/tenant/http.go b/internal/tenant/http.go deleted file mode 100644 index 1b2d7426..00000000 --- a/internal/tenant/http.go +++ /dev/null @@ -1,14 +0,0 @@ -package tenant - -import ( - "strconv" -) - -// HttpExtraHeader returns header we send to gitserver given a tenant context. -func HttpExtraHeader(tenantID int) string { - key := headerKeyTenantID + ": " - if !EnforceTenant() { - return key + "1" - } - return key + strconv.Itoa(tenantID) -} diff --git a/internal/tenant/internal/tenanttype/type.go b/internal/tenant/internal/tenanttype/type.go index a179e82b..f99c61aa 100644 --- a/internal/tenant/internal/tenanttype/type.go +++ b/internal/tenant/internal/tenanttype/type.go @@ -3,12 +3,7 @@ package tenanttype import ( "context" "fmt" - "runtime/pprof" "strconv" - - "github.com/sourcegraph/zoekt/internal/tenant/internal/enforcement" - - "go.uber.org/atomic" ) type Tenant struct { @@ -29,37 +24,11 @@ func WithTenant(ctx context.Context, tenant *Tenant) context.Context { return context.WithValue(ctx, tenantKey, tenant) } -const skipLogging contextKey = iota - -var ErrNoTenantInContext = fmt.Errorf("no tenant in context") - -func FromContext(ctx context.Context) (*Tenant, error) { +func GetTenant(ctx context.Context) (*Tenant, bool) { tnt, ok := ctx.Value(tenantKey).(*Tenant) - if !ok { - if pprofMissingTenant != nil { - _, ok := ctx.Value(skipLogging).(contextKey) - if !ok { - // We want to track every stack trace, so need a unique value for the event - eventValue := pprofUniqID.Add(1) - - // skip stack for Add and this function (2). - pprofMissingTenant.Add(eventValue, 2) - } - } - - return nil, ErrNoTenantInContext - } - return tnt, nil + return tnt, ok } -var pprofUniqID atomic.Int64 -var pprofMissingTenant = func() *pprof.Profile { - if !enforcement.ShouldLogNoTenant() { - return nil - } - return pprof.NewProfile("missing_tenant") -}() - func Unmarshal(s string) (*Tenant, error) { id, err := strconv.Atoi(s) if err != nil { diff --git a/internal/tenant/internal/tenanttype/type_test.go b/internal/tenant/internal/tenanttype/type_test.go index 2821b823..4a24a116 100644 --- a/internal/tenant/internal/tenanttype/type_test.go +++ b/internal/tenant/internal/tenanttype/type_test.go @@ -11,13 +11,13 @@ func TestTenantRoundtrip(t *testing.T) { ctx := context.Background() tenantID := 42 ctxWithTenant := WithTenant(ctx, &Tenant{tenantID}) - tenant, err := FromContext(ctxWithTenant) - require.NoError(t, err) + tenant, ok := GetTenant(ctxWithTenant) + require.True(t, ok) require.Equal(t, tenantID, tenant.ID()) } func TestFromContextWithoutTenant(t *testing.T) { ctx := context.Background() - _, err := FromContext(ctx) - require.Equal(t, ErrNoTenantInContext, err) + _, ok := GetTenant(ctx) + require.False(t, ok) } diff --git a/internal/tenant/query.go b/internal/tenant/query.go index 749bce53..3dbfd48f 100644 --- a/internal/tenant/query.go +++ b/internal/tenant/query.go @@ -2,8 +2,6 @@ package tenant import ( "context" - - "github.com/sourcegraph/zoekt/internal/tenant/internal/tenanttype" ) // EqualsID returns true if the tenant ID in the context matches the @@ -12,7 +10,7 @@ func EqualsID(ctx context.Context, id int) bool { if !EnforceTenant() { return true } - t, err := tenanttype.FromContext(ctx) + t, err := FromContext(ctx) if err != nil { return false } diff --git a/internal/tenant/tenanttest/tenanttest.go b/internal/tenant/tenanttest/tenanttest.go index 1e58b541..3d42a2d3 100644 --- a/internal/tenant/tenanttest/tenanttest.go +++ b/internal/tenant/tenanttest/tenanttest.go @@ -1,9 +1,13 @@ package tenanttest import ( + "context" "testing" + "go.uber.org/atomic" + "github.com/sourcegraph/zoekt/internal/tenant/internal/enforcement" + "github.com/sourcegraph/zoekt/internal/tenant/internal/tenanttype" ) func MockEnforce(t *testing.T) { @@ -13,7 +17,29 @@ func MockEnforce(t *testing.T) { old := enforcement.EnforcementMode.Load() t.Cleanup(func() { enforcement.EnforcementMode.Store(old) + ResetTestTenants() }) enforcement.EnforcementMode.Store("strict") } + +// TestTenantCounter is a counter that is tracks tenants created from NewTestContext(). +var TestTenantCounter atomic.Int64 + +func NewTestContext() context.Context { + return tenanttype.WithTenant(context.Background(), mustTenantFromID(int(TestTenantCounter.Inc()))) +} + +// ResetTestTenants resets the test tenant counter that tracks the tenants +// created from NewTestContext(). +func ResetTestTenants() { + TestTenantCounter.Store(0) +} + +func mustTenantFromID(id int) *tenanttype.Tenant { + tenant, err := tenanttype.FromID(id) + if err != nil { + panic(err) + } + return tenant +}