Skip to content

Commit

Permalink
Strict type Dependabot::Clients::Azure (dependabot#9042)
Browse files Browse the repository at this point in the history
  • Loading branch information
JamieMagee authored Feb 13, 2024
1 parent 9fc744e commit 4784ec8
Show file tree
Hide file tree
Showing 10 changed files with 146 additions and 53 deletions.
136 changes: 112 additions & 24 deletions common/lib/dependabot/clients/azure.rb
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# typed: true
# typed: strict
# frozen_string_literal: true

require "dependabot/shared_helpers"
Expand All @@ -7,6 +7,7 @@

module Dependabot
module Clients
# rubocop:disable Metrics/ClassLength
class Azure
extend T::Sig

Expand All @@ -24,12 +25,16 @@ class Forbidden < StandardError; end

class TagsCreationForbidden < StandardError; end

RETRYABLE_ERRORS = [InternalServerError, BadGateway, ServiceNotAvailable].freeze
RETRYABLE_ERRORS = T.let(
[InternalServerError, BadGateway, ServiceNotAvailable].freeze,
T::Array[T.class_of(StandardError)]
)

#######################
# Constructor methods #
#######################

sig { params(source: Dependabot::Source, credentials: T::Array[Dependabot::Credential]).returns(Azure) }
def self.for_source(source:, credentials:)
credential =
credentials
Expand All @@ -43,15 +48,24 @@ def self.for_source(source:, credentials:)
# Client #
##########

sig do
params(
source: Dependabot::Source,
credentials: T.nilable(Dependabot::Credential),
max_retries: T.nilable(Integer)
)
.void
end
def initialize(source, credentials, max_retries: 3)
@source = source
@credentials = credentials
@auth_header = auth_header_for(credentials&.fetch("token", nil))
@max_retries = max_retries || 3
@auth_header = T.let(auth_header_for(credentials&.fetch("token", nil)), T::Hash[String, String])
@max_retries = T.let(max_retries || 3, Integer)
end

sig { params(_repo: T.nilable(String), branch: String).returns(String) }
def fetch_commit(_repo, branch)
response = get(source.api_endpoint +
response = get(T.must(source.api_endpoint) +
source.organization + "/" + source.project +
"/_apis/git/repositories/" + source.unscoped_repo +
"/stats/branches?name=" + branch)
Expand All @@ -61,46 +75,56 @@ def fetch_commit(_repo, branch)
JSON.parse(response.body).fetch("commit").fetch("commitId")
end

sig { params(_repo: String).returns(String) }
def fetch_default_branch(_repo)
response = get(source.api_endpoint +
response = get(T.must(source.api_endpoint) +
source.organization + "/" + source.project +
"/_apis/git/repositories/" + source.unscoped_repo)

JSON.parse(response.body).fetch("defaultBranch").gsub("refs/heads/", "")
end

sig do
params(
commit: T.nilable(String),
path: T.nilable(String)
)
.returns(T::Array[T::Hash[String, T.untyped]])
end
def fetch_repo_contents(commit = nil, path = nil)
tree = fetch_repo_contents_treeroot(commit, path)

response = get(source.api_endpoint +
response = get(T.must(source.api_endpoint) +
source.organization + "/" + source.project +
"/_apis/git/repositories/" + source.unscoped_repo +
"/trees/" + tree + "?recursive=false")

JSON.parse(response.body).fetch("treeEntries")
end

sig { params(commit: T.nilable(String), path: T.nilable(String)).returns(String) }
def fetch_repo_contents_treeroot(commit = nil, path = nil)
actual_path = path
actual_path = "/" if path.to_s.empty?

tree_url = source.api_endpoint +
tree_url = T.must(source.api_endpoint) +
source.organization + "/" + source.project +
"/_apis/git/repositories/" + source.unscoped_repo +
"/items?path=" + actual_path
"/items?path=" + T.must(actual_path)

unless commit.to_s.empty?
tree_url += "&versionDescriptor.versionType=commit" \
"&versionDescriptor.version=" + commit
"&versionDescriptor.version=" + T.must(commit)
end

tree_response = get(tree_url)

JSON.parse(tree_response.body).fetch("objectId")
end

sig { params(commit: String, path: String).returns(String) }
def fetch_file_contents(commit, path)
response = get(source.api_endpoint +
response = get(T.must(source.api_endpoint) +
source.organization + "/" + source.project +
"/_apis/git/repositories/" + source.unscoped_repo +
"/items?path=" + path +
Expand All @@ -110,30 +134,33 @@ def fetch_file_contents(commit, path)
response.body
end

sig { params(branch_name: T.nilable(String)).returns(T::Array[T::Hash[String, T.untyped]]) }
def commits(branch_name = nil)
commits_url = source.api_endpoint +
commits_url = T.must(source.api_endpoint) +
source.organization + "/" + source.project +
"/_apis/git/repositories/" + source.unscoped_repo +
"/commits"

commits_url += "?searchCriteria.itemVersion.version=" + branch_name unless branch_name.to_s.empty?
commits_url += "?searchCriteria.itemVersion.version=" + T.must(branch_name) unless branch_name.to_s.empty?

response = get(commits_url)

JSON.parse(response.body).fetch("value")
end

sig { params(branch_name: String).returns(T.nilable(T::Hash[String, T.untyped])) }
def branch(branch_name)
response = get(source.api_endpoint +
response = get(T.must(source.api_endpoint) +
source.organization + "/" + source.project +
"/_apis/git/repositories/" + source.unscoped_repo +
"/refs?filter=heads/" + branch_name)

JSON.parse(response.body).fetch("value").first
end

sig { params(source_branch: String, target_branch: String).returns(T::Array[T::Hash[String, T.untyped]]) }
def pull_requests(source_branch, target_branch)
response = get(source.api_endpoint +
response = get(T.must(source.api_endpoint) +
source.organization + "/" + source.project +
"/_apis/git/repositories/" + source.unscoped_repo +
"/pullrequests?searchCriteria.status=all" \
Expand All @@ -143,6 +170,16 @@ def pull_requests(source_branch, target_branch)
JSON.parse(response.body).fetch("value")
end

sig do
params(
branch_name: String,
base_commit: String,
commit_message: String,
files: T::Array[Dependabot::DependencyFile],
author_details: T.nilable(T::Hash[String, String])
)
.returns(T.untyped)
end
def create_commit(branch_name, base_commit, commit_message, files,
author_details)
content = {
Expand All @@ -158,7 +195,7 @@ def create_commit(branch_name, base_commit, commit_message, files,
changeType: "edit",
item: { path: file.path },
newContent: {
content: Base64.encode64(file.content),
content: Base64.encode64(T.must(file.content)),
contentType: "base64encoded"
}
}
Expand All @@ -167,12 +204,25 @@ def create_commit(branch_name, base_commit, commit_message, files,
]
}

post(source.api_endpoint + source.organization + "/" + source.project +
post(T.must(source.api_endpoint) + source.organization + "/" + source.project +
"/_apis/git/repositories/" + source.unscoped_repo +
"/pushes?api-version=5.0", content.to_json)
end

# rubocop:disable Metrics/ParameterLists
sig do
params(
pr_name: String,
source_branch: String,
target_branch: String,
pr_description: String,
labels: T::Array[String],
reviewers: T.nilable(T::Array[String]),
assignees: T.nilable(T::Array[String]),
work_item: T.nilable(Integer)
)
.returns(T.untyped)
end
def create_pull_request(pr_name, source_branch, target_branch,
pr_description, labels,
reviewers = nil, assignees = nil, work_item = nil)
Expand All @@ -187,12 +237,25 @@ def create_pull_request(pr_name, source_branch, target_branch,
workItemRefs: [{ id: work_item }]
}

post(source.api_endpoint +
post(T.must(source.api_endpoint) +
source.organization + "/" + source.project +
"/_apis/git/repositories/" + source.unscoped_repo +
"/pullrequests?api-version=5.0", content.to_json)
end

sig do
params(
pull_request_id: Integer,
auto_complete_set_by: String,
merge_commit_message: String,
delete_source_branch: T::Boolean,
squash_merge: T::Boolean,
merge_strategy: String,
trans_work_items: T::Boolean,
ignore_config_ids: T::Array[String]
)
.returns(T.untyped)
end
def autocomplete_pull_request(pull_request_id, auto_complete_set_by, merge_commit_message,
delete_source_branch = true, squash_merge = true, merge_strategy = "squash",
trans_work_items = true, ignore_config_ids = [])
Expand All @@ -211,22 +274,24 @@ def autocomplete_pull_request(pull_request_id, auto_complete_set_by, merge_commi
}
}

response = patch(source.api_endpoint +
response = patch(T.must(source.api_endpoint) +
source.organization + "/" + source.project +
"/_apis/git/repositories/" + source.unscoped_repo +
"/pullrequests/" + pull_request_id.to_s + "?api-version=5.1", content.to_json)

JSON.parse(response.body)
end

sig { params(pull_request_id: String).returns(T::Hash[String, T.untyped]) }
def pull_request(pull_request_id)
response = get(source.api_endpoint +
response = get(T.must(source.api_endpoint) +
source.organization + "/" + source.project +
"/_apis/git/pullrequests/" + pull_request_id)

JSON.parse(response.body)
end

sig { params(branch_name: String, old_commit: String, new_commit: String).returns(T::Hash[String, T.untyped]) }
def update_ref(branch_name, old_commit, new_commit)
content = [
{
Expand All @@ -236,16 +301,23 @@ def update_ref(branch_name, old_commit, new_commit)
}
]

response = post(source.api_endpoint + source.organization + "/" + source.project +
response = post(T.must(source.api_endpoint) + source.organization + "/" + source.project +
"/_apis/git/repositories/" + source.unscoped_repo +
"/refs?api-version=5.0", content.to_json)

JSON.parse(response.body).fetch("value").first
end
# rubocop:enable Metrics/ParameterLists

sig do
params(
previous_tag: T.nilable(String), new_tag: T.nilable(String),
type: String
)
.returns(T::Array[T::Hash[String, T.untyped]])
end
def compare(previous_tag, new_tag, type)
response = get(source.api_endpoint +
response = get(T.must(source.api_endpoint) +
source.organization + "/" + source.project +
"/_apis/git/repositories/" + source.unscoped_repo +
"/commits?searchCriteria.itemVersion.versionType=#{type}" \
Expand Down Expand Up @@ -311,7 +383,7 @@ def post(url, json) # rubocop:disable Metrics/PerceivedComplexity
raise Unauthorized if response&.status == 401

if response&.status == 403
raise TagsCreationForbidden if tags_creation_forbidden?(response)
raise TagsCreationForbidden if tags_creation_forbidden?(T.must(response))

raise Forbidden
end
Expand Down Expand Up @@ -354,7 +426,8 @@ def patch(url, json)

private

def retry_connection_failures
sig { params(blk: T.proc.void).void }
def retry_connection_failures(&blk) # rubocop:disable Lint/UnusedMethodArgument
retry_attempt = 0

begin
Expand All @@ -365,6 +438,7 @@ def retry_connection_failures
end
end

sig { params(token: T.nilable(String)).returns(T::Hash[String, String]) }
def auth_header_for(token)
return {} unless token

Expand All @@ -379,23 +453,37 @@ def auth_header_for(token)
end
end

sig { params(response: Excon::Response).returns(T::Boolean) }
def tags_creation_forbidden?(response)
return false if response.body.empty?

message = JSON.parse(response.body).fetch("message", nil)
message&.include?("TF401289")
end

sig do
params(
reviewers: T.nilable(T::Array[String]),
assignees: T.nilable(T::Array[String])
)
.returns(T::Array[T::Hash[Symbol, T.untyped]])
end
def pr_reviewers(reviewers, assignees)
return [] unless reviewers || assignees

pr_reviewers = reviewers&.map { |r_id| { id: r_id, isRequired: true } } || []
pr_reviewers + (assignees&.map { |r_id| { id: r_id, isRequired: false } } || [])
end

sig { returns(T::Hash[String, String]) }
attr_reader :auth_header

sig { returns(T.nilable(Dependabot::Credential)) }
attr_reader :credentials

sig { returns(Dependabot::Source) }
attr_reader :source
end
# rubocop:enable Metrics/ClassLength
end
end
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ def github_client
def azure_client
@azure_client ||=
T.let(
Dependabot::Clients::Azure.for_source(source: source, credentials: credentials),
Dependabot::Clients::Azure.for_source(source: T.must(source), credentials: credentials),
T.nilable(Dependabot::Clients::Azure)
)
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def github_client
def azure_client
@azure_client ||=
T.let(
Dependabot::Clients::Azure.for_source(source: source, credentials: credentials),
Dependabot::Clients::Azure.for_source(source: T.must(source), credentials: credentials),
T.nilable(Dependabot::Clients::Azure)
)
end
Expand Down
2 changes: 1 addition & 1 deletion common/lib/dependabot/pull_request_creator/azure.rb
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def azure_client_for_source

def branch_exists?
azure_client_for_source.branch(branch_name)
rescue ::Azure::Error::NotFound
rescue ::Dependabot::Clients::Azure::NotFound
false
end

Expand Down
Loading

0 comments on commit 4784ec8

Please sign in to comment.