Skip to content

Commit

Permalink
Refactor Trillian client with exported methods (#1454)
Browse files Browse the repository at this point in the history
This allows the Trillian client to be used in other parts of the
codebase besides the api package. The changes included exporting the
Response fields and all but one of the struct methods. Also removed
ranges since these weren't used outside the api package.

Signed-off-by: Hayden Blauzvern <hblauzvern@google.com>
  • Loading branch information
haydentherapper authored May 2, 2023
1 parent 5d6e972 commit 46ac0b2
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 181 deletions.
3 changes: 2 additions & 1 deletion pkg/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/sigstore/rekor/pkg/sharding"
"github.com/sigstore/rekor/pkg/signer"
"github.com/sigstore/rekor/pkg/storage"
"github.com/sigstore/rekor/pkg/util"
"github.com/sigstore/sigstore/pkg/cryptoutils"
"github.com/sigstore/sigstore/pkg/signature"
"github.com/sigstore/sigstore/pkg/signature/options"
Expand Down Expand Up @@ -82,7 +83,7 @@ func NewAPI(treeID uint) (*API, error) {
tid := int64(treeID)
if tid == 0 {
log.Logger.Info("No tree ID specified, attempting to create a new tree")
t, err := createAndInitTree(ctx, logAdminClient, logClient)
t, err := util.CreateAndInitTree(ctx, logAdminClient, logClient)
if err != nil {
return nil, fmt.Errorf("create and init tree: %w", err)
}
Expand Down
58 changes: 29 additions & 29 deletions pkg/api/entries.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func signEntry(ctx context.Context, signer signature.Signer, entry models.LogEnt
}

// logEntryFromLeaf creates a signed LogEntry struct from trillian structs
func logEntryFromLeaf(ctx context.Context, signer signature.Signer, tc TrillianClient, leaf *trillian.LogLeaf,
func logEntryFromLeaf(ctx context.Context, signer signature.Signer, tc util.TrillianClient, leaf *trillian.LogLeaf,
signedLogRoot *trillian.SignedLogRoot, proof *trillian.Proof, tid int64, ranges sharding.LogRanges) (models.LogEntry, error) {

log.ContextLogger(ctx).Debugf("log entry from leaf %d", leaf.GetLeafIndex())
Expand All @@ -93,7 +93,7 @@ func logEntryFromLeaf(ctx context.Context, signer signature.Signer, tc TrillianC
return nil, fmt.Errorf("signing entry error: %w", err)
}

scBytes, err := util.CreateAndSignCheckpoint(ctx, viper.GetString("rekor_server.hostname"), tc.logID, root, api.signer)
scBytes, err := util.CreateAndSignCheckpoint(ctx, viper.GetString("rekor_server.hostname"), tid, root, api.signer)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -186,16 +186,16 @@ func createLogEntry(params entries.CreateLogEntryParams) (models.LogEntry, middl
return nil, handleRekorAPIError(params, http.StatusInternalServerError, err, failedToGenerateCanonicalEntry)
}

tc := NewTrillianClient(ctx)
tc := util.NewTrillianClient(ctx, api.logClient, api.logID)

resp := tc.addLeaf(leaf)
resp := tc.AddLeaf(leaf)
// this represents overall GRPC response state (not the results of insertion into the log)
if resp.status != codes.OK {
return nil, handleRekorAPIError(params, http.StatusInternalServerError, fmt.Errorf("grpc error: %w", resp.err), trillianUnexpectedResult)
if resp.Status != codes.OK {
return nil, handleRekorAPIError(params, http.StatusInternalServerError, fmt.Errorf("grpc error: %w", resp.Err), trillianUnexpectedResult)
}

// this represents the results of inserting the proposed leaf into the log; status is nil in success path
insertionStatus := resp.getAddResult.QueuedLeaf.Status
insertionStatus := resp.GetAddResult.QueuedLeaf.Status
if insertionStatus != nil {
switch insertionStatus.Code {
case int32(code.Code_OK):
Expand All @@ -212,10 +212,10 @@ func createLogEntry(params entries.CreateLogEntryParams) (models.LogEntry, middl
// We made it this far, that means the entry was successfully added.
metricNewEntries.Inc()

queuedLeaf := resp.getAddResult.QueuedLeaf.Leaf
queuedLeaf := resp.GetAddResult.QueuedLeaf.Leaf

uuid := hex.EncodeToString(queuedLeaf.GetMerkleLeafHash())
activeTree := fmt.Sprintf("%x", tc.logID)
activeTree := fmt.Sprintf("%x", api.logID)
entryIDstruct, err := sharding.CreateEntryIDFromParts(activeTree, uuid)
if err != nil {
err := fmt.Errorf("error creating EntryID from active treeID %v and uuid %v: %w", activeTree, uuid, err)
Expand Down Expand Up @@ -271,15 +271,15 @@ func createLogEntry(params entries.CreateLogEntryParams) (models.LogEntry, middl
}

root := &ttypes.LogRootV1{}
if err := root.UnmarshalBinary(resp.getLeafAndProofResult.SignedLogRoot.LogRoot); err != nil {
if err := root.UnmarshalBinary(resp.GetLeafAndProofResult.SignedLogRoot.LogRoot); err != nil {
return nil, handleRekorAPIError(params, http.StatusInternalServerError, fmt.Errorf("error unmarshalling log root: %v", err), sthGenerateError)
}
hashes := []string{}
for _, hash := range resp.getLeafAndProofResult.Proof.Hashes {
for _, hash := range resp.GetLeafAndProofResult.Proof.Hashes {
hashes = append(hashes, hex.EncodeToString(hash))
}

scBytes, err := util.CreateAndSignCheckpoint(ctx, viper.GetString("rekor_server.hostname"), tc.logID, root, api.signer)
scBytes, err := util.CreateAndSignCheckpoint(ctx, viper.GetString("rekor_server.hostname"), api.logID, root, api.signer)
if err != nil {
return nil, handleRekorAPIError(params, http.StatusInternalServerError, err, sthGenerateError)
}
Expand Down Expand Up @@ -405,22 +405,22 @@ func SearchLogQueryHandler(params entries.SearchLogQueryParams) middleware.Respo
for i, hash := range searchHashes {
var results map[int64]*trillian.GetEntryAndProofResponse
for _, shard := range api.logRanges.AllShards() {
tcs := NewTrillianClientFromTreeID(httpReqCtx, shard)
resp := tcs.getLeafAndProofByHash(hash)
switch resp.status {
tcs := util.NewTrillianClient(httpReqCtx, api.logClient, shard)
resp := tcs.GetLeafAndProofByHash(hash)
switch resp.Status {
case codes.OK:
leafResult := resp.getLeafAndProofResult
leafResult := resp.GetLeafAndProofResult
if leafResult != nil && leafResult.Leaf != nil {
if results == nil {
results = map[int64]*trillian.GetEntryAndProofResponse{}
}
results[shard] = resp.getLeafAndProofResult
results[shard] = resp.GetLeafAndProofResult
}
case codes.NotFound:
// do nothing here, do not throw 404 error
continue
default:
return handleRekorAPIError(params, http.StatusInternalServerError, fmt.Errorf("error getLeafAndProofByHash(%s): code: %v, msg %v", hex.EncodeToString(hash), resp.status, resp.err), trillianCommunicationError)
return handleRekorAPIError(params, http.StatusInternalServerError, fmt.Errorf("error getLeafAndProofByHash(%s): code: %v, msg %v", hex.EncodeToString(hash), resp.Status, resp.Err), trillianCommunicationError)
}
}
searchByHashResults[i] = results
Expand All @@ -431,7 +431,7 @@ func SearchLogQueryHandler(params entries.SearchLogQueryParams) middleware.Respo
if leafResp == nil {
continue
}
tcs := NewTrillianClientFromTreeID(httpReqCtx, shard)
tcs := util.NewTrillianClient(httpReqCtx, api.logClient, shard)
logEntry, err := logEntryFromLeaf(httpReqCtx, api.signer, tcs, leafResp.Leaf, leafResp.SignedLogRoot, leafResp.Proof, shard, api.logRanges)
if err != nil {
return handleRekorAPIError(params, http.StatusInternalServerError, err, err.Error())
Expand Down Expand Up @@ -461,19 +461,19 @@ func retrieveLogEntryByIndex(ctx context.Context, logIndex int) (models.LogEntry
log.ContextLogger(ctx).Infof("Retrieving log entry by index %d", logIndex)

tid, resolvedIndex := api.logRanges.ResolveVirtualIndex(logIndex)
tc := NewTrillianClientFromTreeID(ctx, tid)
tc := util.NewTrillianClient(ctx, api.logClient, tid)
log.ContextLogger(ctx).Debugf("Retrieving resolved index %v from TreeID %v", resolvedIndex, tid)

resp := tc.getLeafAndProofByIndex(resolvedIndex)
switch resp.status {
resp := tc.GetLeafAndProofByIndex(resolvedIndex)
switch resp.Status {
case codes.OK:
case codes.NotFound, codes.OutOfRange, codes.InvalidArgument:
return models.LogEntry{}, ErrNotFound
default:
return models.LogEntry{}, fmt.Errorf("grpc err: %w: %s", resp.err, trillianCommunicationError)
return models.LogEntry{}, fmt.Errorf("grpc err: %w: %s", resp.Err, trillianCommunicationError)
}

result := resp.getLeafAndProofResult
result := resp.GetLeafAndProofResult
leaf := result.Leaf
if leaf == nil {
return models.LogEntry{}, ErrNotFound
Expand Down Expand Up @@ -525,13 +525,13 @@ func retrieveUUIDFromTree(ctx context.Context, uuid string, tid int64) (models.L
return models.LogEntry{}, types.ValidationError(err)
}

tc := NewTrillianClientFromTreeID(ctx, tid)
tc := util.NewTrillianClient(ctx, api.logClient, tid)
log.ContextLogger(ctx).Debugf("Attempting to retrieve UUID %v from TreeID %v", uuid, tid)

resp := tc.getLeafAndProofByHash(hashValue)
switch resp.status {
resp := tc.GetLeafAndProofByHash(hashValue)
switch resp.Status {
case codes.OK:
result := resp.getLeafAndProofResult
result := resp.GetLeafAndProofResult
leaf := result.Leaf
if leaf == nil {
return models.LogEntry{}, ErrNotFound
Expand All @@ -546,7 +546,7 @@ func retrieveUUIDFromTree(ctx context.Context, uuid string, tid int64) (models.L
case codes.NotFound:
return models.LogEntry{}, ErrNotFound
default:
log.ContextLogger(ctx).Errorf("Unexpected response code while attempting to retrieve UUID %v from TreeID %v: %v", uuid, tid, resp.status)
log.ContextLogger(ctx).Errorf("Unexpected response code while attempting to retrieve UUID %v from TreeID %v: %v", uuid, tid, resp.Status)
return models.LogEntry{}, errors.New("unexpected error")
}
}
Expand Down
4 changes: 1 addition & 3 deletions pkg/api/public_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,8 @@ import (
)

func GetPublicKeyHandler(params pubkey.GetPublicKeyParams) middleware.Responder {
ctx := params.HTTPRequest.Context()
treeID := swag.StringValue(params.TreeID)
tc := NewTrillianClient(ctx)
pk, err := tc.ranges.PublicKey(api.pubkey, treeID)
pk, err := api.logRanges.PublicKey(api.pubkey, treeID)
if err != nil {
return handleRekorAPIError(params, http.StatusBadRequest, err, "")
}
Expand Down
40 changes: 20 additions & 20 deletions pkg/api/tlog.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ import (

// GetLogInfoHandler returns the current size of the tree and the STH
func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder {
tc := NewTrillianClient(params.HTTPRequest.Context())
tc := util.NewTrillianClient(params.HTTPRequest.Context(), api.logClient, api.logID)

// for each inactive shard, get the loginfo
var inactiveShards []*models.InactiveShardLogInfo
for _, shard := range tc.ranges.GetInactive() {
if shard.TreeID == tc.ranges.ActiveTreeID() {
for _, shard := range api.logRanges.GetInactive() {
if shard.TreeID == api.logRanges.ActiveTreeID() {
break
}
// Get details for this inactive shard
Expand All @@ -52,11 +52,11 @@ func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder {
inactiveShards = append(inactiveShards, is)
}

resp := tc.getLatest(0)
if resp.status != codes.OK {
return handleRekorAPIError(params, http.StatusInternalServerError, fmt.Errorf("grpc error: %w", resp.err), trillianCommunicationError)
resp := tc.GetLatest(0)
if resp.Status != codes.OK {
return handleRekorAPIError(params, http.StatusInternalServerError, fmt.Errorf("grpc error: %w", resp.Err), trillianCommunicationError)
}
result := resp.getLatestResult
result := resp.GetLatestResult

root := &types.LogRootV1{}
if err := root.UnmarshalBinary(result.SignedLogRoot.LogRoot); err != nil {
Expand All @@ -67,7 +67,7 @@ func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder {
treeSize := int64(root.TreeSize)

scBytes, err := util.CreateAndSignCheckpoint(params.HTTPRequest.Context(),
viper.GetString("rekor_server.hostname"), tc.ranges.ActiveTreeID(), root, api.signer)
viper.GetString("rekor_server.hostname"), api.logRanges.ActiveTreeID(), root, api.signer)
if err != nil {
return handleRekorAPIError(params, http.StatusInternalServerError, err, sthGenerateError)
}
Expand All @@ -76,7 +76,7 @@ func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder {
RootHash: &hashString,
TreeSize: &treeSize,
SignedTreeHead: stringPointer(string(scBytes)),
TreeID: stringPointer(fmt.Sprintf("%d", tc.logID)),
TreeID: stringPointer(fmt.Sprintf("%d", api.logID)),
InactiveShards: inactiveShards,
}

Expand All @@ -92,21 +92,21 @@ func GetLogProofHandler(params tlog.GetLogProofParams) middleware.Responder {
if *params.FirstSize > params.LastSize {
return handleRekorAPIError(params, http.StatusBadRequest, nil, fmt.Sprintf(firstSizeLessThanLastSize, *params.FirstSize, params.LastSize))
}
tc := NewTrillianClient(params.HTTPRequest.Context())
tc := util.NewTrillianClient(params.HTTPRequest.Context(), api.logClient, api.logID)
if treeID := swag.StringValue(params.TreeID); treeID != "" {
id, err := strconv.Atoi(treeID)
if err != nil {
log.Logger.Infof("Unable to convert %s to string, skipping initializing client with Tree ID: %v", treeID, err)
} else {
tc = NewTrillianClientFromTreeID(params.HTTPRequest.Context(), int64(id))
tc = util.NewTrillianClient(params.HTTPRequest.Context(), api.logClient, int64(id))
}
}

resp := tc.getConsistencyProof(*params.FirstSize, params.LastSize)
if resp.status != codes.OK {
return handleRekorAPIError(params, http.StatusInternalServerError, fmt.Errorf("grpc error: %w", resp.err), trillianCommunicationError)
resp := tc.GetConsistencyProof(*params.FirstSize, params.LastSize)
if resp.Status != codes.OK {
return handleRekorAPIError(params, http.StatusInternalServerError, fmt.Errorf("grpc error: %w", resp.Err), trillianCommunicationError)
}
result := resp.getConsistencyProofResult
result := resp.GetConsistencyProofResult

var root types.LogRootV1
if err := root.UnmarshalBinary(result.SignedLogRoot.LogRoot); err != nil {
Expand Down Expand Up @@ -136,12 +136,12 @@ func GetLogProofHandler(params tlog.GetLogProofParams) middleware.Responder {
}

func inactiveShardLogInfo(ctx context.Context, tid int64) (*models.InactiveShardLogInfo, error) {
tc := NewTrillianClientFromTreeID(ctx, tid)
resp := tc.getLatest(0)
if resp.status != codes.OK {
return nil, fmt.Errorf("resp code is %d", resp.status)
tc := util.NewTrillianClient(ctx, api.logClient, tid)
resp := tc.GetLatest(0)
if resp.Status != codes.OK {
return nil, fmt.Errorf("resp code is %d", resp.Status)
}
result := resp.getLatestResult
result := resp.GetLatestResult

root := &types.LogRootV1{}
if err := root.UnmarshalBinary(result.SignedLogRoot.LogRoot); err != nil {
Expand Down
Loading

0 comments on commit 46ac0b2

Please sign in to comment.