-
Notifications
You must be signed in to change notification settings - Fork 0
/
sfdc_jwt_demo.py
147 lines (112 loc) · 4.19 KB
/
sfdc_jwt_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
'''
Module to play around with sending JWT token info to Salesforce and receiving an access token back
'''
from base64 import urlsafe_b64encode
from datetime import datetime, timedelta
from urllib.parse import unquote
import requests
from flask import (
Flask,
render_template,
jsonify,
)
import os
from cfenv import AppEnv
from Crypto.Hash import SHA256
from Crypto.PublicKey import RSA
from Crypto.Signature import PKCS1_v1_5
# JOSE header of the token; jwt_claim() base64url encodes it as the first JWT segment
JWT_HEADER = '{"alg":"RS256"}'
# "iss" claim - presumably the consumer key of the Salesforce connected app (confirm)
JWT_CLAIM_ISS = "3MVG9ZF4bs_.MKuiyRq9J1l33OAR0jFoQbkx1am4Bzh5VDU1L5oSB500dxKTwobSPA7NuaVgl8VWwWV5tp_Vg"
# "sub" claim - presumably the Salesforce username the token is requested for (confirm)
JWT_CLAIM_SUB = "shgupta_dev@pivotal.io"
# "aud" claim sent inside the token
JWT_CLAIM_AUD = "https://login.salesforce.com"
# Token endpoint that index() POSTs the signed assertion to
JWT_AUTH_EP = "https://login.salesforce.com/services/oauth2/token"
# Flask application object that the routes and error handlers below are registered on
app = Flask(__name__)
def jwt_claim():
    '''
    Function to package the JWT header and claim data in base64 encoded form.
    The claims are the module level iss / sub / aud constants plus an "exp"
    timestamp 300 seconds in the future.
    :return:
    "<header>.<claims>" as str, each part urlsafe base64 encoded. The "="
    padding is NOT removed here; index() strips it.
    '''
    # Function-scope imports keep this block self-contained.
    import json
    import logging
    from datetime import timezone

    # Timezone-aware "now": a naive local datetime is ambiguous around DST
    # changes, which can shift the computed expiry by an hour.
    expires_at = datetime.now(tz=timezone.utc) + timedelta(seconds=300)
    expiration_ts = int(expires_at.timestamp())

    # json.dumps escapes quotes/backslashes in the values; for the current
    # constants the output is byte-identical to the old str.format template.
    payload = json.dumps({
        "iss": JWT_CLAIM_ISS,
        "sub": JWT_CLAIM_SUB,
        "aud": JWT_CLAIM_AUD,
        "exp": expiration_ts,
    })
    # Debug-level log instead of an unconditional print to stdout on every call.
    logging.getLogger(__name__).debug("JWT claim payload: %s", payload)

    encoded_header = urlsafe_b64encode(JWT_HEADER.encode()).decode()
    encoded_payload = urlsafe_b64encode(payload.encode()).decode()
    return encoded_header + "." + encoded_payload
def credhub_secret():
    '''
    Look up the service labelled "credhub" through cfenv's AppEnv, take its
    "demo-certificate" entry, URL-unquote it and parse it into a data structure.
    :return: parsed credhub value (get_private_key / get_certificate index it as a dict)
    :raises ValueError: if the value is neither a Python literal nor valid JSON
    '''
    # Function-scope imports keep this block self-contained.
    import ast
    import json

    cf_env = AppEnv()
    raw_value = cf_env.get_service(label="credhub").get_url("demo-certificate")
    decoded_value = unquote(raw_value)
    # SECURITY: this used to be eval(), which executes arbitrary Python found in
    # the environment. literal_eval accepts the same dict/list/str/number
    # literals without running code.
    try:
        return ast.literal_eval(decoded_value)
    except (ValueError, SyntaxError):
        # JSON spellings such as true / false / null are not Python literals
        # (eval() would have failed on them as well), so parse those as JSON.
        return json.loads(decoded_value)
def get_private_key():
    '''
    Fetch the private key out of the Credhub secret
    :return: the "private_key" entry stored under the secret's "value" key
    '''
    secret_value = credhub_secret()["value"]
    private_key = secret_value["private_key"]
    return private_key
def get_certificate():
    '''
    Fetch the certificate out of the Credhub secret
    :return: the "certificate" entry stored under the secret's "value" key
    '''
    secret_value = credhub_secret()["value"]
    certificate = secret_value["certificate"]
    return certificate
def sign_data(data):
    '''
    Sign data with the private key returned by get_private_key(), using a
    PKCS1 v1.5 signature over the SHA256 digest of the encoded string.
    param: data str to be signed
    return: urlsafe base64 encoded signature as str ("=" padding still present)
    '''
    rsa_key = RSA.importKey(get_private_key())
    pkcs_signer = PKCS1_v1_5.new(rsa_key)
    message_hash = SHA256.new(data.encode())
    raw_signature = pkcs_signer.sign(message_hash)
    return urlsafe_b64encode(raw_signature).decode()
def do_auth(endpoint, data, timeout=30):
    '''
    Function to POST the JWT bearer assertion to the SFDC /oauth2/token endpoint
    :param endpoint: URL to POST to
    :param data: form fields to send in the request body
    :param timeout: seconds to wait for the server before giving up; without a
    timeout requests waits indefinitely and the calling web request hangs
    :return:
    the requests response object (the caller reads the token from its body)
    :raises requests.exceptions.Timeout: if the server does not answer in time
    '''
    return requests.post(endpoint, data=data, timeout=timeout)
@app.route("/")
@app.route("/index")
def index():
    '''
    Build a JWT bearer assertion, sign it, POST it to the token endpoint and
    show every intermediate value plus the server's answer.
    :return:
    JSON response with the keys claim, signed_claim, target_payload,
    response_text and response_headers
    '''
    # Keeping with JWS spec, we need to remove the padding "=" characters from base64 encoded string
    claim = jwt_claim().replace("=", "")
    signed_claim = sign_data(claim).replace("=", "")
    target_payload = claim + "." + signed_claim
    auth_payload = {"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", "assertion": target_payload}
    response = do_auth(JWT_AUTH_EP, data=auth_payload)
    # SECURITY: the body comes from a remote server, so it is parsed as JSON and
    # never passed to eval() (eval would execute whatever the server sent and
    # also fails on the JSON literals true / false / null).
    try:
        response_text = response.json()
    except ValueError:
        # Not JSON (e.g. an HTML error page): show the raw body instead of crashing
        response_text = response.text
    # Plain dict copy of the headers so jsonify can serialize them
    response_headers = dict(response.headers)
    return_dict = {"claim": claim, "signed_claim": signed_claim, "target_payload": target_payload,
                   "response_text": response_text, "response_headers": response_headers}
    return jsonify(return_dict)
@app.errorhandler(404)
def page_not_found(e):
    # Answer 404 errors with the rendered 404.html template
    body = render_template("404.html")
    return body, 404
@app.errorhandler(500)
def server_error(e):
    # Answer 500 errors with the rendered 500.html template
    body = render_template("500.html")
    return body, 500
if __name__ == "__main__":
    port = int(os.getenv("PORT", 5000))
    # SECURITY: debug mode serves the interactive Werkzeug debugger, which lets
    # anyone who can reach the port run arbitrary code. The app listens on all
    # IPs, so debug (and the reloader) are opt-in via FLASK_DEBUG, not always on.
    debug_enabled = os.getenv("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes", "on")
    # Run the app, listening on all IPs with our chosen port number
    app.run(host="0.0.0.0", port=port, debug=debug_enabled, use_reloader=debug_enabled)