-
Notifications
You must be signed in to change notification settings - Fork 0
/
agent.py
124 lines (104 loc) · 3.92 KB
/
agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os
import socket
import struct
import sys
from paramiko.message import Message
from paramiko.agent import Agent
from paramiko.ed25519key import Ed25519Key
from paramiko.ssh_exception import SSHException
import subprocess
import io
class CustomSSHAgent(Agent):
    """Minimal SSH agent serving Ed25519 keys whose secrets live in ``pass``.

    Only the two requests needed for public-key authentication are handled:
    listing identities and signing data. Every other request, and any sign
    request for a key this agent does not hold, is answered with
    SSH_AGENT_FAILURE.

    NOTE(review): ``Agent.__init__`` is never called, so none of the base
    class's setup runs; the subclass only reuses the type. Confirm intended.
    """

    # Message numbers from the SSH agent protocol.
    AGENTC_REQUEST_IDENTITIES = bytes([11])
    AGENTC_SIGN_REQUEST = bytes([13])
    AGENT_IDENTITIES_ANSWER = bytes([12])
    AGENT_SIGN_RESPONSE = bytes([14])
    # Generic failure reply; clients treat any other unexpected type as a
    # malformed response rather than a refusal.
    AGENT_FAILURE = bytes([5])

    def __init__(self, keys):
        """Load every supported key described in *keys*.

        Args:
            keys: Iterable of dicts with ``"type"``, ``"key_path"``,
                ``"pass_key"`` and ``"store_location"`` entries. Only
                ``"ed25519"`` keys are loaded; other types are skipped with a
                message on stderr. A key that fails to load (missing config
                entry, unreadable secret, wrong passphrase, bad key data) is
                reported on stderr and skipped instead of aborting startup.
        """
        self.keys = []
        for spec in keys:
            key_path = spec.get("key_path")
            key_type = str(spec.get("type", ""))
            if key_type.lower() != "ed25519":
                print(
                    f"Skipping unsupported key type {key_type!r} for {key_path}",
                    file=sys.stderr,
                )
                continue
            try:
                passphrase = self.get_passphrase(spec["pass_key"], spec["store_location"])
                # NOTE(review): the private key is read from the default
                # password store, not store_location (unchanged behaviour).
                private_key_str = self.get_private_key(spec["key_path"])
                if private_key_str is None:
                    continue  # get_private_key already reported the failure
                self.keys.append(
                    Ed25519Key.from_private_key(
                        io.StringIO(private_key_str), password=passphrase
                    )
                )
            except (SSHException, KeyError) as e:
                print(f"Error loading key from {key_path}: {e}", file=sys.stderr)

    @staticmethod
    def _run_pass(entry, store_location=None):
        """Run ``pass <entry>`` and return ``(stripped_stdout, None)`` or ``(None, error_text)``.

        When *store_location* is given it is exported as PASSWORD_STORE_DIR;
        otherwise ``pass`` runs with the inherited environment. A missing or
        non-executable ``pass`` binary is reported as an error, not raised.
        """
        env = None
        if store_location is not None:
            env = os.environ.copy()
            env["PASSWORD_STORE_DIR"] = store_location
        try:
            result = subprocess.run(["pass", entry], capture_output=True, text=True, env=env)
        except OSError as e:
            return None, str(e)
        if result.returncode != 0:
            return None, result.stderr
        return result.stdout.strip(), None

    def get_passphrase(self, pass_key, store_location):
        """Return the passphrase stored under *pass_key* in *store_location*, or None on failure."""
        value, err = self._run_pass(pass_key, store_location)
        if value is None:
            print(f"Error retrieving passphrase for key {pass_key}: {err}", file=sys.stderr)
        return value

    def get_private_key(self, key_path, store_location=None):
        """Return the private key text stored under *key_path*, or None on failure.

        Args:
            key_path: ``pass`` entry name holding the PEM/OpenSSH private key.
            store_location: Optional PASSWORD_STORE_DIR; defaults to the
                store ``pass`` would use on its own.
        """
        value, err = self._run_pass(key_path, store_location)
        if value is None:
            print(f"Error retrieving private key for key_path {key_path}: {err}", file=sys.stderr)
        return value

    def _failure(self):
        """Return a bare SSH_AGENT_FAILURE reply payload."""
        failure_response = Message()
        failure_response.add_byte(self.AGENT_FAILURE)
        return failure_response.asbytes()

    def handle_message(self, msg):
        """Handle one request payload and return the reply payload.

        Args:
            msg: Request bytes without the 4-byte length prefix.

        Returns:
            Reply bytes without length prefix: an identities answer, a sign
            response, or SSH_AGENT_FAILURE for unsupported requests and for
            sign requests naming a key this agent does not hold.
        """
        m = Message(msg)
        cmd = m.get_byte()
        if cmd == self.AGENTC_REQUEST_IDENTITIES:
            resp = Message()
            resp.add_byte(self.AGENT_IDENTITIES_ANSWER)
            resp.add_int(len(self.keys))
            for key in self.keys:
                resp.add_string(key.asbytes())  # public key blob
                resp.add_string(f"{key.get_name()} key loaded from {key}")  # comment
            return resp.asbytes()
        if cmd == self.AGENTC_SIGN_REQUEST:
            key_blob = m.get_string()
            data = m.get_string()
            # Flags are consumed and ignored: the protocol only defines flags
            # selecting RSA signature hashes, which do not apply to Ed25519.
            m.get_int()
            for key in self.keys:
                if key_blob == key.asbytes():
                    resp = Message()
                    resp.add_byte(self.AGENT_SIGN_RESPONSE)
                    resp.add_string(key.sign_ssh_data(data))
                    return resp.asbytes()
        return self._failure()
def _recv_exact(conn, n):
    """Read exactly *n* bytes from *conn*; return None if EOF arrives first.

    ``recv`` on a stream socket may return fewer bytes than requested, so
    keep reading until the full amount has arrived.
    """
    buf = bytearray()
    while len(buf) < n:
        chunk = conn.recv(n - len(buf))
        if not chunk:
            return None
        buf += chunk
    return bytes(buf)


def handle_client(agent, conn, addr, max_msg_len=256 * 1024):
    """Serve framed agent requests on one connection until the peer disconnects.

    Each request is a 4-byte big-endian length followed by that many payload
    bytes. The payload is passed to ``agent.handle_message`` and the returned
    bytes are sent back with the same framing.

    Args:
        agent: Object with a ``handle_message(bytes) -> bytes`` method.
        conn: Connected stream socket; always closed before returning.
        addr: Peer address from ``accept()``; unused, kept for the signature.
        max_msg_len: Largest request payload accepted. A peer announcing a
            longer message is disconnected instead of making the server
            allocate and wait for it.

    Errors while serving this client are reported on stderr and end only
    this connection, so one misbehaving client cannot stop the agent.
    """
    try:
        while True:
            header = _recv_exact(conn, 4)
            if header is None:
                break  # peer closed (possibly mid-header)
            (msg_len,) = struct.unpack(">I", header)
            if msg_len > max_msg_len:
                print(
                    f"Dropping client: message length {msg_len} exceeds {max_msg_len}",
                    file=sys.stderr,
                )
                break
            msg = _recv_exact(conn, msg_len)
            if msg is None:
                break  # peer closed mid-message
            response = agent.handle_message(msg)
            conn.sendall(struct.pack(">I", len(response)) + response)
    except Exception as e:
        print(f"Error handling client: {e}", file=sys.stderr)
    finally:
        conn.close()
# Keys the agent serves. Each entry points at secrets kept in `pass`:
#   type            key algorithm; CustomSSHAgent only loads "ed25519"
#   key_path        pass entry holding the private key text
#   pass_key        pass entry holding that key's passphrase
#   store_location  PASSWORD_STORE_DIR used when looking up pass_key
# The values below are placeholders to be replaced with real entries.
KEYS = [
    dict(
        type="ed25519",
        key_path="your-key-path",
        pass_key="your-pass-key",
        store_location="/path/to/your/store",
    ),
]
if __name__ == "__main__":
    import stat

    if len(sys.argv) != 2:
        print(f"Usage: {sys.argv[0]} <socket_path>", file=sys.stderr)
        sys.exit(1)
    socket_path = sys.argv[1]

    # Load keys only after the arguments are validated: loading shells out to
    # `pass`, which may prompt for a GPG passphrase.
    agent = CustomSSHAgent(KEYS)

    # Remove a stale socket left by a previous run, but refuse to delete
    # anything that is not a socket (e.g. a mistyped path to a real file).
    try:
        st = os.lstat(socket_path)
    except FileNotFoundError:
        pass
    else:
        if not stat.S_ISSOCK(st.st_mode):
            print(f"Refusing to replace non-socket path: {socket_path}", file=sys.stderr)
            sys.exit(1)
        os.unlink(socket_path)

    sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    # Create the socket file as 0600 from the start. Relying on chmod after
    # bind alone leaves a window where other users could connect.
    old_umask = os.umask(0o177)
    try:
        sock.bind(socket_path)
    finally:
        os.umask(old_umask)
    os.chmod(socket_path, 0o600)
    try:
        sock.listen(1)
        # Clients are served one at a time, sequentially.
        while True:
            conn, addr = sock.accept()
            handle_client(agent, conn, addr)
    finally:
        sock.close()
        try:
            os.unlink(socket_path)
        except FileNotFoundError:
            pass