|
| 1 | +from tpm2_pytss import * |
| 2 | + |
| 3 | +from cloud_auth_tpm.base import BaseCredential |
| 4 | +from tpm2_pytss.tsskey import TSSPrivKey |
| 5 | + |
| 6 | +import json |
| 7 | +from datetime import datetime |
| 8 | +import hmac |
| 9 | + |
| 10 | +import datetime |
| 11 | +import hashlib |
| 12 | +import requests # pip install requests |
| 13 | + |
| 14 | +from boto3 import Session |
| 15 | +from botocore.credentials import RefreshableCredentials, CredentialProvider |
| 16 | +from botocore.session import get_session |
| 17 | + |
| 18 | + |
| 19 | +class AWSHMACCredentials(CredentialProvider): |
| 20 | + |
| 21 | + METHOD = 'tpm-hmac-credential' |
| 22 | + |
| 23 | + def __init__( |
| 24 | + self, |
| 25 | + tcti=None, |
| 26 | + keyfile=None, |
| 27 | + ownerpassword=None, |
| 28 | + password=None, |
| 29 | + policy_impl=None, |
| 30 | + |
| 31 | + region=None, |
| 32 | + duration_seconds=3600, |
| 33 | + |
| 34 | + access_key=None, |
| 35 | + assume_role_arn=None, |
| 36 | + role_session_name=None, |
| 37 | + tags=None, |
| 38 | + |
| 39 | + get_session_token=False, |
| 40 | + **kwargs: Any |
| 41 | + ): |
| 42 | + CredentialProvider.__init__(self) |
| 43 | + |
| 44 | + self._tcti = tcti |
| 45 | + self._keyfile = keyfile |
| 46 | + |
| 47 | + self._ownerpassword = ownerpassword |
| 48 | + self._password = password |
| 49 | + self._policy_impl = policy_impl |
| 50 | + |
| 51 | + self._region = region |
| 52 | + self._duration_seconds = duration_seconds |
| 53 | + |
| 54 | + self._access_key = access_key |
| 55 | + self._assume_role_arn = assume_role_arn |
| 56 | + self._role_session_name = role_session_name |
| 57 | + self._tags = tags |
| 58 | + |
| 59 | + self._get_session_token = get_session_token |
| 60 | + |
| 61 | + if self._assume_role_arn == '' and self._get_session_token == False: |
| 62 | + raise Exception("Error : {}".format( |
| 63 | + "if get_session_token is not set, _assume_role_arn values must must be specified")) |
| 64 | + |
| 65 | + self._long_running_session = None |
| 66 | + |
| 67 | + if self._tcti == '' or self._region == '' or self._keyfile == '' or self._region == '' or self._access_key == '': |
| 68 | + raise Exception("Error : {}".format( |
| 69 | + "tcti, region _keyfile, _role_arn access_key and profile_arn must be specified")) |
| 70 | + |
| 71 | + # Load public keyfile |
| 72 | + with open(keyfile, 'r') as f: |
| 73 | + self._keyfile = f.read() |
| 74 | + f.close() |
| 75 | + |
| 76 | + # ref: https://dev.to/li_chastina/auto-refresh-aws-tokens-using-iam-role-and-boto3-2cjf |
| 77 | + session = get_session() |
| 78 | + session_credentials = RefreshableCredentials.create_from_metadata(metadata=self._refresh(), |
| 79 | + refresh_using=self._refresh, |
| 80 | + method=self.METHOD) |
| 81 | + session._credentials = session_credentials |
| 82 | + session.set_config_variable('region', self._region) |
| 83 | + self._long_running_session = Session(botocore_session=session) |
| 84 | + |
| 85 | + def _sign(self, key, msg): |
| 86 | + return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest() |
| 87 | + |
| 88 | + def _getSignatureKey(self, key, dateStamp, regionName, serviceName): |
| 89 | + |
| 90 | + # instead of the first hmac operation using the AWS Secret, use the TPM based key |
| 91 | + # kDate = sign(("AWS4" + key).encode("utf-8"), dateStamp) |
| 92 | + ectx = ESAPI(tcti=self._tcti) |
| 93 | + ectx.startup(TPM2_SU.CLEAR) |
| 94 | + |
| 95 | + k = TSSPrivKey.from_pem(self._keyfile.encode('utf-8')) |
| 96 | + inSensitiveOwner = TPM2B_SENSITIVE_CREATE(TPMS_SENSITIVE_CREATE()) |
| 97 | + |
| 98 | + if self._ownerpassword != None: |
| 99 | + inSensitiveOwner = TPM2B_SENSITIVE_CREATE( |
| 100 | + TPMS_SENSITIVE_CREATE(userAuth=TPM2B_AUTH(self._ownerpassword))) |
| 101 | + |
| 102 | + primary1, _, _, _, _ = ectx.create_primary( |
| 103 | + inSensitiveOwner, TPM2B_PUBLIC(publicArea=BaseCredential._parent_ecc_template)) |
| 104 | + |
| 105 | + hkeyLoaded = ectx.load(primary1, k.private, k.public) |
| 106 | + ectx.flush_context(primary1) |
| 107 | + |
| 108 | + if self._password != None: |
| 109 | + ectx.tr_set_auth(hkeyLoaded, self._password) |
| 110 | + |
| 111 | + if self._policy_impl == None: |
| 112 | + thmac = ectx.hmac(hkeyLoaded, dateStamp, TPM2_ALG.SHA256) |
| 113 | + else: |
| 114 | + sess = self._policy_impl.policy_callback(ectx=ectx) |
| 115 | + thmac = ectx.hmac(hkeyLoaded, dateStamp, |
| 116 | + TPM2_ALG.SHA256, session1=sess) |
| 117 | + ectx.flush_context(sess) |
| 118 | + |
| 119 | + ectx.flush_context(hkeyLoaded) |
| 120 | + kDate = thmac.__bytes__() |
| 121 | + ectx.close() |
| 122 | + |
| 123 | + kRegion = self._sign(kDate, regionName) |
| 124 | + kService = self._sign(kRegion, serviceName) |
| 125 | + kSigning = self._sign(kService, "aws4_request") |
| 126 | + return kSigning |
| 127 | + |
| 128 | + def _refresh(self): |
| 129 | + try: |
| 130 | + # Request parameters |
| 131 | + method = 'POST' |
| 132 | + service = 'sts' |
| 133 | + host = "sts.amazonaws.com" |
| 134 | + |
| 135 | + endpoint = '/' |
| 136 | + |
| 137 | + # Create a datetime object for signing |
| 138 | + t = datetime.datetime.utcnow() |
| 139 | + amzdate = t.strftime('%Y%m%dT%H%M%SZ') |
| 140 | + datestamp = t.strftime('%Y%m%d') |
| 141 | + |
| 142 | + # Create the canonical request |
| 143 | + canonical_uri = endpoint |
| 144 | + |
| 145 | + canonical_querystring = '' |
| 146 | + canonical_headers = 'content-type:application/x-www-form-urlencoded' + \ |
| 147 | + '\n' + 'host:' + host + '\n' + 'x-amz-date:' + amzdate + '\n' |
| 148 | + |
| 149 | + signed_headers = 'content-type;host;x-amz-date' |
| 150 | + |
| 151 | + # https://docs.aws.amazon.com/STS/latest/APIReference/API_GetCallerIdentity.html |
| 152 | + # payload = 'Action=GetCallerIdentity&Version=2011-06-15' |
| 153 | + if self._get_session_token: |
| 154 | + payload = 'Action=GetSessionToken&DurationSeconds={}&Version=2011-06-15'.format( |
| 155 | + self._duration_seconds) |
| 156 | + else: |
| 157 | + payload = 'Action=AssumeRole&DurationSeconds={}&RoleSessionName={}&RoleArn={}&Version=2011-06-15'.format( |
| 158 | + self._duration_seconds, self._region, self._assume_role_arn) |
| 159 | + |
| 160 | + payload_hash = hashlib.sha256(payload.encode('utf-8')).hexdigest() |
| 161 | + canonical_request = (method + '\n' + canonical_uri + '\n' + canonical_querystring + '\n' |
| 162 | + + canonical_headers + '\n' + signed_headers + '\n' + payload_hash) |
| 163 | + |
| 164 | + # Create the string to sign |
| 165 | + algorithm = 'AWS4-HMAC-SHA256' |
| 166 | + credential_scope = datestamp + '/' + self._region + \ |
| 167 | + '/' + service + '/' + 'aws4_request' |
| 168 | + string_to_sign = (algorithm + '\n' + amzdate + '\n' + credential_scope + '\n' + |
| 169 | + hashlib.sha256(canonical_request.encode('utf-8')).hexdigest()) |
| 170 | + |
| 171 | + # Sign the string |
| 172 | + signing_key = self._getSignatureKey( |
| 173 | + self._access_key, datestamp, self._region, service) |
| 174 | + signature = hmac.new(signing_key, (string_to_sign).encode( |
| 175 | + 'utf-8'), hashlib.sha256).hexdigest() |
| 176 | + |
| 177 | + # Add signing information to the request |
| 178 | + authorization_header = (algorithm + ' ' + 'Credential=' + self._access_key + '/' + credential_scope + ', ' + |
| 179 | + 'SignedHeaders=' + signed_headers + ', ' + 'Signature=' + signature) |
| 180 | + # Make the request |
| 181 | + headers = {'content-type': 'application/x-www-form-urlencoded', |
| 182 | + 'host': host, |
| 183 | + 'x-amz-date': amzdate, |
| 184 | + 'Accept': 'application/json', |
| 185 | + 'Authorization': authorization_header} |
| 186 | + request_url = 'https://' + host + canonical_uri |
| 187 | + |
| 188 | + r = requests.post(request_url, headers=headers, |
| 189 | + data=payload, allow_redirects=False, timeout=5) |
| 190 | + |
| 191 | + if r.status_code != 200: |
| 192 | + raise Exception("Error status code " + |
| 193 | + str(r.status_code) + " " + r.text) |
| 194 | + |
| 195 | + json_data = json.loads(r.text) |
| 196 | + |
| 197 | + if self._get_session_token: |
| 198 | + if len(json_data['GetSessionTokenResponse']) == 0: |
| 199 | + raise Exception( |
| 200 | + "invalid response, no GetSessionTokenResponse ") |
| 201 | + c = json_data['GetSessionTokenResponse']['GetSessionTokenResult']['Credentials'] |
| 202 | + else: |
| 203 | + if len(json_data['AssumeRoleResponse']) == 0: |
| 204 | + raise Exception("invalid response, no AssumeRoleResponse ") |
| 205 | + c = json_data['AssumeRoleResponse']['AssumeRoleResult']['Credentials'] |
| 206 | + |
| 207 | + datetime_object = datetime.datetime.utcfromtimestamp( |
| 208 | + int(c['Expiration'])) |
| 209 | + |
| 210 | + metadata = { |
| 211 | + 'access_key': c['AccessKeyId'], |
| 212 | + 'secret_key': c['SecretAccessKey'], |
| 213 | + 'token': c['SessionToken'], |
| 214 | + 'expiry_time': datetime_object.replace(tzinfo=datetime.UTC).isoformat() |
| 215 | + } |
| 216 | + |
| 217 | + return metadata |
| 218 | + except Exception as e: |
| 219 | + raise Exception("Error : {}".format(e)) |
| 220 | + |
| 221 | + def get_session(self): |
| 222 | + return self._long_running_session |
0 commit comments