Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev #10

Merged
merged 12 commits into from
Sep 18, 2024
4 changes: 3 additions & 1 deletion asyncsnmplib/asn1.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class Class(enum.IntEnum):
TNumber = Union[Number, int]
TType = Union[Type, int]
TClass = Union[Class, int]
TOid = Tuple[int, ...]
TValue = Any


class Tag(NamedTuple):
Expand Down Expand Up @@ -546,7 +548,7 @@ def _decode_null(bytes_data: bytes) -> None:
raise Error("ASN1 syntax error")

@staticmethod
def _decode_object_identifier(bytes_data: bytes) -> tuple:
def _decode_object_identifier(bytes_data: bytes) -> TOid:
result: List[int] = []
value: int = 0
for i in range(len(bytes_data)):
Expand Down
86 changes: 43 additions & 43 deletions asyncsnmplib/client.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,32 @@
import asyncio
from typing import Iterable, Optional, Tuple, List, Type
from .exceptions import (
SnmpNoConnection,
SnmpErrorNoSuchName,
SnmpTooMuchRows,
SnmpNoAuthParams,
)
from .asn1 import Tag, TOid, TValue
from .package import SnmpMessage
from .pdu import SnmpGet, SnmpGetNext, SnmpGetBulk
from .protocol import SnmpProtocol
from .v3.auth import AUTH_PROTO
from .v3.encr import PRIV_PROTO
from .v3.auth import Auth
from .v3.encr import Priv
from .v3.package import SnmpV3Message
from .v3.protocol import SnmpV3Protocol


class Snmp:
version = 1 # = v2

def __init__(self, host, port=161, community='public', max_rows=10000):
self._loop = asyncio.get_event_loop()
def __init__(
self,
host: str,
port: int = 161,
community: str = 'public',
max_rows: int = 10_000,
loop: Optional[asyncio.AbstractEventLoop] = None):
self._loop = loop if loop else asyncio.get_running_loop()
self._protocol = None
self._transport = None
self.host = host
Expand All @@ -28,7 +36,7 @@ def __init__(self, host, port=161, community='public', max_rows=10000):

# On some systems it seems to be required to set the remote_addr argument
# https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.create_datagram_endpoint
async def connect(self, timeout=10):
async def connect(self, timeout: float = 10.0):
try:
infos = await self._loop.getaddrinfo(self.host, self.port)
family, *_, addr = infos[0]
Expand All @@ -44,7 +52,7 @@ async def connect(self, timeout=10):
self._transport = transport

def _get(self, oids, timeout=None):
if self._transport is None:
if self._protocol is None:
raise SnmpNoConnection
pdu = SnmpGet(0, oids)
message = SnmpMessage.make(self.version, self.community, pdu)
Expand All @@ -54,32 +62,34 @@ def _get(self, oids, timeout=None):
return self._protocol.send(message)

def _get_next(self, oids):
if self._transport is None:
if self._protocol is None:
raise SnmpNoConnection
pdu = SnmpGetNext(0, oids)
message = SnmpMessage.make(self.version, self.community, pdu)
return self._protocol.send(message)

def _get_bulk(self, oids):
if self._transport is None:
if self._protocol is None:
raise SnmpNoConnection
pdu = SnmpGetBulk(0, oids)
message = SnmpMessage.make(self.version, self.community, pdu)
return self._protocol.send(message)

async def get(self, oid, timeout=None):
async def get(self, oid: TOid, timeout: Optional[float] = None
) -> Tuple[TOid, Tag, TValue]:
vbs = await self._get([oid], timeout)
return vbs[0]

async def get_next(self, oid):
async def get_next(self, oid: TOid) -> Tuple[TOid, Tag, TValue]:
vbs = await self._get_next([oid])
return vbs[0]

async def get_next_multi(self, oids):
async def get_next_multi(self, oids: Iterable[TOid]
) -> List[Tuple[TOid, TValue]]:
vbs = await self._get_next(oids)
return [(oid, value) for oid, _, value in vbs if oid[:-1] in oids]

async def walk(self, oid):
async def walk(self, oid: TOid) -> List[Tuple[TOid, TValue]]:
next_oid = oid
prefixlen = len(oid)
rows = []
Expand Down Expand Up @@ -115,7 +125,7 @@ def close(self):
class SnmpV1(Snmp):
version = 0

async def walk(self, oid):
async def walk(self, oid: TOid) -> List[Tuple[TOid, TValue]]:
next_oid = oid
prefixlen = len(oid)
rows = []
Expand Down Expand Up @@ -150,15 +160,14 @@ class SnmpV3(Snmp):

def __init__(
self,
host,
username,
auth_proto='USM_AUTH_NONE',
auth_passwd=None,
priv_proto='USM_PRIV_NONE',
priv_passwd=None,
port=161,
max_rows=10000):
self._loop = asyncio.get_event_loop()
host: str,
username: str,
auth: Optional[Tuple[Type[Auth], str]] = None,
priv: Optional[Tuple[Type[Priv], str]] = None,
port: int = 161,
max_rows: int = 10_000,
loop: Optional[asyncio.AbstractEventLoop] = None):
self._loop = loop if loop else asyncio.get_running_loop()
self._protocol = None
self._transport = None
self.host = host
Expand All @@ -170,28 +179,16 @@ def __init__(
self._auth_hash_localized = None
self._priv_hash = None
self._priv_hash_localized = None
try:
self._auth_proto = AUTH_PROTO[auth_proto]
except KeyError:
raise Exception('Supply valid auth_proto')
try:
self._priv_proto = PRIV_PROTO[priv_proto]
except KeyError:
raise Exception('Supply valid priv_proto')
if self._priv_proto and not self._auth_proto:
raise Exception('Supply auth_proto')
if self._auth_proto:
if auth_passwd is None:
raise Exception('Supply auth_passwd')
if auth is not None:
self._auth_proto, auth_passwd = auth
self._auth_hash = self._auth_proto.hash_passphrase(auth_passwd)
if self._priv_proto:
if priv_passwd is None:
raise Exception('Supply priv_passwd')
self._priv_hash = self._auth_proto.hash_passphrase(priv_passwd)
if priv is not None:
self._priv_proto, priv_passwd = priv
self._priv_hash = self._auth_proto.hash_passphrase(priv_passwd)

# On some systems it seems to be required to set the remote_addr argument
# https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.create_datagram_endpoint
async def connect(self, timeout=10):
async def connect(self, timeout: float = 10.0):
try:
infos = await self._loop.getaddrinfo(self.host, self.port)
family, *_, addr = infos[0]
Expand All @@ -211,6 +208,9 @@ async def connect(self, timeout=10):
raise SnmpNoAuthParams

async def _get_auth_params(self, timeout=10):
# TODO for long requests this will need to be refreshed
# https://datatracker.ietf.org/doc/html/rfc3414#section-2.2.3
assert self._protocol is not None
pdu = SnmpGet(0, [])
message = SnmpV3Message.make(pdu, [b'', 0, 0, b'', b'', b''])
# this request will not retry like the other requests
Expand All @@ -225,7 +225,7 @@ async def _get_auth_params(self, timeout=10):
if self._priv_proto else None

def _get(self, oids, timeout=None):
if self._transport is None:
if self._protocol is None:
raise SnmpNoConnection
elif self._auth_params is None:
raise SnmpNoAuthParams
Expand All @@ -248,7 +248,7 @@ def _get(self, oids, timeout=None):
self._priv_hash_localized)

def _get_next(self, oids):
if self._transport is None:
if self._protocol is None:
raise SnmpNoConnection
elif self._auth_params is None:
raise SnmpNoAuthParams
Expand All @@ -262,7 +262,7 @@ def _get_next(self, oids):
self._priv_hash_localized)

def _get_bulk(self, oids):
if self._transport is None:
if self._protocol is None:
raise SnmpNoConnection
elif self._auth_params is None:
raise SnmpNoAuthParams
Expand Down
1 change: 0 additions & 1 deletion asyncsnmplib/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
__all__ = (
"SnmpTimeoutError",
"SnmpUnsupportedValueType",
"SnmpErrorTooBig",
"SnmpErrorNoSuchName",
"SnmpErrorBadValue",
Expand Down
4 changes: 4 additions & 0 deletions asyncsnmplib/mib/mib.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ def on_mib(mi: dict, mibname: str, mib: dict, lk_definitions: dict):
obj['syntax'] = obj['syntax']['syntax']

lk_definitions[name] = obj
elif obj['tp'] == 'TRAP-TYPE':
lk_definitions[name] = obj

for name, obj in mib.items():
if 'value' in obj:
Expand All @@ -90,6 +92,8 @@ def on_mib(mi: dict, mibname: str, mib: dict, lk_definitions: dict):
names[name] = obj
elif obj['tp'] == 'OBJECT-GROUP':
names[name] = obj
elif obj['tp'] == 'NOTIFICATION-TYPE':
names[name] = obj

for name, obj in names.items():
other_name = name
Expand Down
21 changes: 14 additions & 7 deletions asyncsnmplib/mib/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Tuple, Union
from typing import Tuple, Union, List
from ..asn1 import TOid, TValue
from .mib_index import MIB_INDEX
from .syntax_funs import SYNTAX_FUNS

Expand All @@ -8,7 +9,7 @@
FLAGS_SEPERATOR = ','


def on_octet_string(value: bytes) -> str:
def on_octet_string(value: TValue) -> Union[str, None]:
"""
used as a fallback for OCTET STRING when no formatter is found/defined
"""
Expand All @@ -18,13 +19,13 @@ def on_octet_string(value: bytes) -> str:
return


def on_integer(value: int) -> str:
def on_integer(value: TValue) -> Union[int, None]:
if not isinstance(value, int):
return
return value


def on_oid_map(oid: Tuple[int]) -> str:
def on_oid_map(oid: TValue) -> Union[str, None]:
if not isinstance(oid, tuple):
# some devices don't follow mib's syntax
# for example ipAddressTable.ipAddressPrefix returns an int in case of
Expand All @@ -45,7 +46,7 @@ def on_value_map_b(value: bytes, map_: dict) -> str:
v for k, v in map_.items() if value[k // 8] & (1 << k % 8))


def on_syntax(syntax: dict, value: Union[int, str, bytes]):
def on_syntax(syntax: dict, value: TValue):
"""
this is point where bytes are converted to right datatype
"""
Expand All @@ -65,7 +66,10 @@ def on_syntax(syntax: dict, value: Union[int, str, bytes]):
raise Exception(f'Invalid syntax {syntax}')


def on_result(base_oid: Tuple[int], result: dict) -> Tuple[str, list]:
def on_result(
base_oid: TOid,
result: List[Tuple[TOid, TValue]],
) -> Tuple[str, List[dict]]:
"""returns a more compat result (w/o prefixes) and groups formatted
metrics by base_oid
"""
Expand Down Expand Up @@ -109,7 +113,10 @@ def on_result(base_oid: Tuple[int], result: dict) -> Tuple[str, list]:
return result_name, list(table.values())


def on_result_base(base_oid: Tuple[int], result: dict) -> Tuple[str, list]:
def on_result_base(
base_oid: TOid,
result: List[Tuple[TOid, TValue]],
) -> Tuple[str, List[dict]]:
"""returns formatted metrics grouped by base_oid
"""
base = MIB_INDEX[base_oid]
Expand Down
12 changes: 7 additions & 5 deletions asyncsnmplib/package.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .asn1 import Decoder, Encoder, Number
from typing import Optional, Tuple, List
from .asn1 import Decoder, Encoder, Number, Tag, TOid, TValue


class Package:
Expand All @@ -8,12 +9,13 @@ class Package:
pdu = None

def __init__(self):
self.request_id = None
self.error_status = None
self.error_index = None
self.variable_bindings = []
self.request_id: Optional[int] = None
self.error_status: Optional[int] = None
self.error_index: Optional[int] = None
self.variable_bindings: List[Tuple[TOid, Tag, TValue]] = []

def encode(self):
assert self.pdu is not None
encoder = Encoder()

with encoder.enter(Number.Sequence):
Expand Down
11 changes: 7 additions & 4 deletions asyncsnmplib/protocol.py
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line 50 is a fix.

Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class SnmpProtocol(asyncio.DatagramProtocol):
__slots__ = ('loop', 'target', 'transport', 'requests', '_request_id')

def __init__(self, target):
self.loop = asyncio.get_event_loop()
self.loop = asyncio.get_running_loop()
self.target = target
self.requests = {}
self._request_id = 0
Expand All @@ -47,8 +47,11 @@ def datagram_received(self, data: bytes, *args):
# before request_id is known we cannot do anything and the query
# will time out
pid = pkg.request_id
if pid is not None:
if pid in self.requests:
self.requests[pid].set_exception(exceptions.SnmpDecodeError)
elif pid is not None:
logging.error(
self._log_with_suffix(f'Unknown package pid {pid}'))
else:
logging.error(
self._log_with_suffix('Failed to decode package'))
Expand All @@ -59,9 +62,9 @@ def datagram_received(self, data: bytes, *args):
self._log_with_suffix(f'Unknown package pid {pid}'))
else:
exception = None
if pkg.error_status != 0:
if pkg.error_status: # also exclude None for trap-pdu
oid = None
if pkg.error_index != 0:
if pkg.error_index: # also exclude None for trap-pdu
oidtuple = \
pkg.variable_bindings[pkg.error_index - 1][0]
oid = '.'.join(map(str, oidtuple))
Expand Down
Loading
Loading