Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
torvalds
GitHub Repository: torvalds/linux
Path: blob/master/tools/net/ynl/pyynl/lib/ynl.py
29274 views
1
# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2
3
from collections import namedtuple
4
from enum import Enum
5
import functools
6
import os
7
import random
8
import socket
9
import struct
10
from struct import Struct
11
import sys
12
import ipaddress
13
import uuid
14
import queue
15
import selectors
16
import time
17
18
from .nlspec import SpecFamily
19
20
#
21
# Generic Netlink code which should really be in some library, but I can't quickly find one.
22
#
23
24
25
class Netlink:
    """Numeric protocol constants used throughout this module.

    Grouped as: netlink socket level and options, control message types,
    message header flags, attribute header flag bits, generic netlink /
    nlctrl identifiers, extended-ACK attribute types and policy attribute
    types.
    """

    # --- socket level and setsockopt() options ---
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # --- control message types ---
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    # --- message header flags ---
    NLM_F_REQUEST = 0x001
    NLM_F_ACK = 0x004

    # The upper flag bits are reused with different meanings depending on
    # context, hence the repeated values in the three groups below.
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200
    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    # --- attribute header flag bits (top two bits of the type field) ---
    NLA_F_NESTED = 1 << 15
    NLA_F_NET_BYTEORDER = 1 << 14
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # --- generic netlink ---
    NETLINK_GENERIC = 16
    GENL_ID_CTRL = 0x10

    # nlctrl command and attributes
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    # attributes inside each CTRL_ATTR_MCAST_GROUPS entry
    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # --- extended ACK attribute types ---
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # --- attributes of a policy description (NLMSGERR_ATTR_POLICY) ---
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # Values of NL_POLICY_TYPE_ATTR_TYPE; members are numbered from 1 in
    # the order listed.
    AttrType = Enum('AttrType', ('flag', 'u8', 'u16', 'u32', 'u64',
                                 's8', 's16', 's32', 's64',
                                 'binary', 'string', 'nul-string',
                                 'nested', 'nested-array',
                                 'bitfield32', 'sint', 'uint'))
101
102
class NlError(Exception):
    """Exception wrapping a netlink error reply.

    ``nl_msg.error`` holds the (negative) error code from the reply; it is
    stored negated in ``self.error`` so it can be passed to os.strerror().
    """
    def __init__(self, nl_msg):
        self.nl_msg = nl_msg
        self.error = -nl_msg.error

    def __str__(self):
        reason = os.strerror(self.error)
        return f"Netlink error: {reason}\n{self.nl_msg}"
109
110
111
class ConfigError(Exception):
    """Raised for an invalid configuration, e.g. a receive size that is too small."""
113
114
115
class NlAttr:
    """One netlink attribute (TLV) parsed from *raw* at byte *offset*.

    Attributes:
        type: attribute type with the NLA_F_NESTED and NLA_F_NET_BYTEORDER
            bits masked off.
        is_nest: non-zero when the NLA_F_NESTED bit is set.
        payload_len: the nla_len header field (it includes the 4-byte header).
        full_len: payload_len rounded up to a multiple of 4, i.e. the distance
            to the next attribute.
        raw: the attribute payload without the header.
    """
    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        # nla_len counts the header itself, so anything below 4 is malformed.
        # A zero length would otherwise make every caller that advances by
        # full_len (NlAttrs, indexed-array decoding) spin on the same offset.
        if self._len < 4:
            raise ValueError(f"Malformed netlink attribute at offset {offset}: "
                             f"length {self._len}")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct for scalar *attr_type*.

        byte_order None selects native order, "big-endian" big-endian, and
        any other value little-endian.
        """
        fmt = cls.type_formats[attr_type]
        if byte_order:
            return fmt.big if byte_order == "big-endian" else fmt.little
        return fmt.native

    def as_scalar(self, attr_type, byte_order=None):
        """Unpack the whole payload as one scalar of *attr_type*."""
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Unpack a variable-width ('sint'/'uint') scalar of 4 or 8 bytes.

        The first letter of *attr_type* ('s' or 'u') picks signedness.
        """
        if len(self.raw) != 4 and len(self.raw) != 8:
            raise Exception(f"Auto-scalar payload must be 4 or 8 bytes, got {len(self.raw)}")
        real_type = attr_type[0] + str(len(self.raw) * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        """Decode the payload as a string, dropping one trailing NUL if present.

        UTF-8 is used (a superset of ASCII) so non-ASCII strings do not raise,
        and the last character is only removed when it really is the NUL.
        """
        raw = self.raw
        if raw.endswith(b'\x00'):
            raw = raw[:-1]
        return raw.decode('utf-8')

    def as_bin(self):
        """Return the raw payload bytes."""
        return self.raw

    def as_c_array(self, type):
        """Unpack the payload as a packed native-order array of *type* scalars."""
        fmt = self.get_format(type)
        return [x[0] for x in fmt.iter_unpack(self.raw)]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
167
168
169
class NlAttrs:
    """Consecutive netlink attributes parsed from *msg*, starting at byte
    *offset* and running to the end of the buffer."""
    def __init__(self, msg, offset=0):
        self.attrs = []
        pos = offset
        end = len(msg)
        while pos < end:
            item = NlAttr(msg, pos)
            self.attrs.append(item)
            pos += item.full_len

    def __iter__(self):
        for item in self.attrs:
            yield item

    def __repr__(self):
        # One attribute per line, no trailing newline.
        return '\n'.join(repr(item) for item in self.attrs)
188
189
190
class NlMsg:
    """One netlink message parsed from *msg* at byte *offset*.

    Attributes:
        hdr: the raw 16-byte message header.
        nl_len, nl_type, nl_flags, nl_seq, nl_portid: header fields
            (native byte order).
        raw: payload following the header, up to nl_len.
        error: the int32 error field of an NLMSG_ERROR / NLMSG_DONE
            payload, 0 for other messages.
        done: 1 for NLMSG_ERROR / NLMSG_DONE messages, else 0.
        extack: dict of decoded extended-ACK attributes when the message
            carries NLM_F_ACK_TLVS, else None. If *attr_space* is given,
            'miss-type' is annotated with the attribute name.
    """
    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset : offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # Extack TLVs follow the error code (4 bytes) and the echoed
            # request header (16 bytes). This assumes a capped ACK
            # (NETLINK_CAP_ACK), i.e. no request payload echoed in between.
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = self._parse_extack(self.raw[extack_off:])
            if attr_space:
                self.annotate_extack(attr_space)

    def _parse_extack(self, raw):
        """Decode the extended-ACK TLVs in *raw* into a dict."""
        extack = dict()
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NLMSGERR_ATTR_MSG:
                extack['msg'] = attr.as_strz()
            elif attr.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                extack['miss-type'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                extack['miss-nest'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NLMSGERR_ATTR_OFFS:
                extack['bad-attr-offs'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NLMSGERR_ATTR_POLICY:
                extack['policy'] = self._decode_policy(attr.raw)
            else:
                extack.setdefault('unknown', []).append(attr)
        return extack

    def _decode_policy(self, raw):
        """Decode an NLMSGERR_ATTR_POLICY payload into a dict."""
        policy = {}
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
                type_id = attr.as_scalar('u32')
                try:
                    policy['type'] = Netlink.AttrType(type_id).name
                except ValueError:
                    # A type value missing from Netlink.AttrType must not turn
                    # error reporting into a crash; keep the raw number.
                    policy['type'] = type_id
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S:
                policy['min-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S:
                policy['max-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U:
                policy['min-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U:
                policy['max-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH:
                policy['min-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH:
                policy['max-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK:
                policy['bitfield32-mask'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MASK:
                policy['mask'] = attr.as_scalar('u64')
        return policy

    def annotate_extack(self, attr_space):
        """ Make extack more human friendly with attribute information """
        # Nothing to annotate on messages without extended ACK data.
        if not self.extack:
            return

        # We don't have the ability to parse nests yet, so only do global
        if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
            miss_type = self.extack['miss-type']
            if miss_type in attr_space.attrs_by_val:
                spec = attr_space.attrs_by_val[miss_type]
                self.extack['miss-type'] = spec['name']
                if 'doc' in spec:
                    self.extack['miss-type-doc'] = spec['doc']

    def cmd(self):
        """Return the message type."""
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"
        if self.error:
            msg += '\n\terror: ' + str(self.error)
        if self.extack:
            msg += '\n\textack: ' + repr(self.extack)
        return msg
281
282
283
class NlMsgs:
    """All netlink messages contained in one receive buffer *data*."""
    def __init__(self, data):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset)
            # A length smaller than the 16-byte header can't be advanced past
            # sensibly (0 would loop forever on the same offset).
            if msg.nl_len < 16:
                raise ValueError(f"Malformed netlink message at offset {offset}: "
                                 f"length {msg.nl_len}")
            # Messages start on 4-byte boundaries (NLMSG_ALIGN); identical to
            # the plain length whenever nl_len is already aligned.
            offset += (msg.nl_len + 3) & ~3
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs
295
296
297
# Process-wide cache of generic netlink families, filled by
# _genl_load_families(). It maps family name to a dict with 'id' and 'name'
# and, when reported, 'maxattr' and 'mcast' (group name -> group id).
# None means the kernel has not been queried yet.
genl_family_name_to_id = None
298
299
300
def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
    """Build a generic netlink request, minus its leading length field.

    The result is the message header after nlmsg_len (type, flags, seq and
    a port id of 0) followed by the genetlink header (cmd, version, 0).
    _genl_msg_finalize() prepends the total length. When *seq* is None a
    random sequence number in [1, 1024] is used.
    """
    sequence = random.randint(1, 1024) if seq is None else seq
    header = struct.pack("HHII", nl_type, nl_flags, sequence, 0)
    return header + struct.pack("BBH", genl_cmd, genl_version, 0)
307
308
309
def _genl_msg_finalize(msg):
    """Prepend the native-endian u32 message length (body plus the 4-byte
    length field itself) to *msg*."""
    total_len = len(msg) + 4
    return struct.pack("I", total_len) + msg
311
312
313
def _genl_parse_mcast_groups(attr):
    """Return {group name: group id} from a CTRL_ATTR_MCAST_GROUPS attribute."""
    groups = dict()
    for entry in NlAttrs(attr.raw):
        mcast_name = None
        mcast_id = None
        for entry_attr in NlAttrs(entry.raw):
            if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                mcast_name = entry_attr.as_strz()
            elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                mcast_id = entry_attr.as_scalar('u32')
        # Skip incomplete entries instead of storing None keys or values.
        if mcast_name and mcast_id is not None:
            groups[mcast_name] = mcast_id
    return groups


def _genl_parse_family(gm):
    """Return a dict describing one family from a CTRL_CMD_GETFAMILY reply."""
    fam = dict()
    for attr in NlAttrs(gm.raw):
        if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
            fam['id'] = attr.as_scalar('u16')
        elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
            fam['name'] = attr.as_strz()
        elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
            fam['maxattr'] = attr.as_scalar('u32')
        elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
            fam['mcast'] = _genl_parse_mcast_groups(attr)
    return fam


def _genl_load_families():
    """Fill genl_family_name_to_id by dumping all families from nlctrl.

    Only families that report both a name and an id are stored. On a
    netlink error the error is reported on stderr and loading stops; the
    families read so far stay in the cache.
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    # Diagnostics go to stderr so stdout stays clean for results;
                    # the reply carries a negative errno.
                    print("Netlink error:", os.strerror(-nl_msg.error), file=sys.stderr)
                    return
                if nl_msg.done:
                    return

                fam = _genl_parse_family(GenlMsg(nl_msg))
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam
360
361
362
class GenlMsg:
    """Generic netlink view of an NlMsg.

    The 4-byte genetlink header (cmd, version, reserved) is split off the
    front of ``nl_msg.raw``; the remainder is kept in ``self.raw``.
    """
    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        """Return the genetlink command."""
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs is only attached by NetlinkProtocol.decode(); fall back to
        # no attributes so repr() never raises on an undecoded message.
        for a in getattr(self, 'raw_attrs', ()):
            msg += '\t\t' + repr(a) + '\n'
        return msg
377
378
379
class NetlinkProtocol:
    """Classic (non-generic) netlink protocol handling.

    Requests carry only the message header (the op's command is used as the
    message type), and replies are decoded against the op's fixed header
    and attribute layout.
    """
    def __init__(self, family_name, proto_num):
        self.family_name, self.proto_num = family_name, proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Pack the header fields after nlmsg_len; port id is always 0 and a
        random sequence number in [1, 1024] is used when *seq* is None."""
        if seq is None:
            seq = random.randint(1, 1024)
        return struct.pack("HHII", nl_type, nl_flags, seq, 0)

    def message(self, flags, command, version, seq=None):
        """Build a request header; *command* is the message type and
        *version* is unused."""
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Wrap *nl_msg* for this protocol and attach its parsed attributes
        as ``raw_attrs``. When *op* is None it is looked up from the
        message's command in ``ynl.rsp_by_value``."""
        decoded = self._decode(nl_msg)
        resolved_op = op if op is not None else ynl.rsp_by_value[decoded.cmd()]
        header_size = ynl._struct_size(resolved_op.fixed_header)
        decoded.raw_attrs = NlAttrs(decoded.raw, header_size)
        return decoded

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the numeric id of multicast group *mcast_name* from the spec."""
        if mcast_name in mcast_groups:
            return mcast_groups[mcast_name].value
        raise Exception(f'Multicast group "{mcast_name}" not present in the spec')

    def msghdr_size(self):
        """Size in bytes of the netlink message header."""
        return struct.calcsize("IHHII")
411
412
413
class GenlProtocol(NetlinkProtocol):
    """Generic netlink protocol handling.

    The family id is resolved through nlctrl (cached process-wide in
    genl_family_name_to_id), and a genetlink header is appended to every
    request. Raises KeyError when the kernel does not know *family_name*.
    """
    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        # Query the kernel's family table only once per process.
        if genl_family_name_to_id is None:
            _genl_load_families()

        family = genl_family_name_to_id[family_name]
        self.genl_family = family
        self.family_id = family['id']

    def message(self, flags, command, version, seq=None):
        """Build the request headers: message header addressed to the family
        id, followed by the genetlink header (cmd, version, 0)."""
        genl_hdr = struct.pack("BBH", command, version, 0)
        return self._message(self.family_id, flags, seq) + genl_hdr

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the kernel-assigned id of multicast group *mcast_name*."""
        groups = self.genl_family['mcast']
        if mcast_name in groups:
            return groups[mcast_name]
        raise Exception(f'Multicast group "{mcast_name}" not present in the family')

    def msghdr_size(self):
        """Netlink header size plus the 4-byte genetlink header."""
        return super().msghdr_size() + 4
439
440
441
class SpaceAttrs:
    """Chain of (attribute-set spec, attribute values) scopes, innermost first.

    Used to find the value of a selector attribute, which may live in the
    current attribute set or in any enclosing one.
    """
    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer = None):
        inner = self.SpecValuesPair(attr_space, attrs)
        self.scopes = [inner, *(outer.scopes if outer else [])]

    def lookup(self, name):
        """Return the value of *name* from the innermost scope whose spec
        defines it.

        Raises:
            Exception: that scope has no value for *name*, or no scope's
                spec defines *name*.
        """
        defining = next((scope for scope in self.scopes if name in scope.spec), None)
        if defining is None:
            raise Exception(f"Attribute '{name}' not defined in any attribute-set")
        if name not in defining.values:
            spec_name = defining.spec.yaml['name']
            raise Exception(
                f"No value for '{name}' in attribute space '{spec_name}'")
        return defining.values[name]
458
459
460
#
461
# YNL implementation details.
462
#
463
464
465
class YnlFamily(SpecFamily):
466
def __init__(self, def_path, schema=None, process_unknown=False,
             recv_size=0):
    """Load the family spec and open a netlink socket for it.

    Args:
        def_path: path of the family spec, passed to SpecFamily.
        schema: passed through to SpecFamily.
        process_unknown: decode unknown attributes / enum values instead of
            raising.
        recv_size: receive buffer size for recv(); 0 selects 131072.

    Raises:
        Exception: the family is not supported by the kernel.
        ConfigError: recv_size is smaller than 4000 bytes.
    """
    super().__init__(def_path, schema)

    self.include_raw = False
    self.process_unknown = process_unknown

    try:
        if self.proto == "netlink-raw":
            self.nlproto = NetlinkProtocol(self.yaml['name'],
                                           self.yaml['protonum'])
        else:
            self.nlproto = GenlProtocol(self.yaml['name'])
    except KeyError as err:
        # Keep the KeyError as the explicit cause for debugging.
        raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel") from err

    self._recv_dbg = False
    # Note that netlink will use conservative (min) message size for
    # the first dump recv() on the socket, our setting will only matter
    # from the second recv() on.
    self._recv_size = recv_size if recv_size else 131072
    # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
    # for a message, so smaller receive sizes will lead to truncation.
    # Note that the min size for other families may be larger than 4k!
    if self._recv_size < 4000:
        raise ConfigError(f"recv_size of {self._recv_size} bytes is too small, "
                          "must be at least 4000")

    self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
    self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
    self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
    self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)

    # Message ids that arrive unsolicited (notifications).
    self.async_msg_ids = set()
    self.async_msg_queue = queue.Queue()

    for msg in self.msgs.values():
        if msg.is_async:
            self.async_msg_ids.add(msg.rsp_value)

    # Expose every op as a method named after it.
    for op_name, op in self.ops.items():
        bound_f = functools.partial(self._op, op_name)
        setattr(self, op.ident_name, bound_f)
508
509
510
def ntf_subscribe(self, mcast_name):
    """Join multicast group *mcast_name* so its notifications reach self.sock.

    The socket is first bound to the address (0, 0), then the group is
    added with NETLINK_ADD_MEMBERSHIP.
    """
    group_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
    sock = self.sock
    sock.bind((0, 0))
    sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP, group_id)
515
516
def set_recv_dbg(self, enabled):
    """Turn printing of every received buffer to stderr on or off."""
    self._recv_dbg = enabled
518
519
def _recv_dbg_print(self, reply, nl_msgs):
    """If receive debugging is enabled, print a summary of *reply* and each
    message parsed from it to stderr."""
    if self._recv_dbg:
        print("Recv: read", len(reply), "bytes,",
              len(nl_msgs.msgs), "messages", file=sys.stderr)
        for message in nl_msgs:
            print(" ", message, file=sys.stderr)
526
527
def _encode_enum(self, attr_spec, value):
    """Translate enum entry name(s) into the integer to send.

    For flag-style enums (enum type 'flags' or the attribute's
    'enum-as-flags') *value* may be a single name or an iterable of names,
    whose flag values are summed. Otherwise *value* is one entry name.
    """
    enum = self.consts[attr_spec['enum']]
    as_flags = enum.type == 'flags' or attr_spec.get('enum-as-flags', False)
    if not as_flags:
        return enum.entries[value].user_value()
    names = [value] if isinstance(value, str) else value
    return sum(enum.entries[n].user_value(as_flags = True) for n in names)
538
539
def _get_scalar(self, attr_spec, value):
    """Convert *value* to an int.

    If int() fails, *value* is tried as an enum name (when the spec has an
    enum) or parsed per the spec's display hint. If neither applies, the
    original int() error is re-raised.
    """
    try:
        return int(value)
    except (ValueError, TypeError) as conversion_error:
        if 'enum' in attr_spec:
            return self._encode_enum(attr_spec, value)
        if not attr_spec.display_hint:
            raise conversion_error
        return self._from_string(value, attr_spec)
548
549
def _add_attr(self, space, name, value, search_attrs):
    """Encode attribute *name* of attribute set *space* with *value*.

    Args:
        space: name of the attribute set defining *name*.
        name: attribute name within *space*.
        value: user-supplied value; its expected form depends on the
            attribute type (dict for nests, list for multi-attrs and
            indexed arrays, ...).
        search_attrs: SpaceAttrs scope used to resolve sub-message selectors.

    Returns:
        The encoded, 4-byte padded attribute bytes (several concatenated
        TLVs for a multi-attr given a list); b'' for a false 'flag'.

    Raises:
        Exception: unknown attribute, unsupported value or unknown type.
    """
    try:
        attr = self.attr_sets[space][name]
    except KeyError:
        raise Exception(f"Space '{space}' has no attribute '{name}'")
    nl_type = attr.value

    if attr.is_multi and isinstance(value, list):
        attr_payload = b''
        for subvalue in value:
            attr_payload += self._add_attr(space, name, subvalue, search_attrs)
        return attr_payload

    if attr["type"] == 'nest':
        nl_type |= Netlink.NLA_F_NESTED
        sub_space = attr['nested-attributes']
        attr_payload = self._add_nest_attrs(value, sub_space, search_attrs)
    elif attr['type'] == 'indexed-array' and attr['sub-type'] == 'nest':
        nl_type |= Netlink.NLA_F_NESTED
        sub_space = attr['nested-attributes']
        attr_payload = self._encode_indexed_array(value, sub_space,
                                                  search_attrs)
    elif attr["type"] == 'flag':
        if not value:
            # If value is absent or false then skip attribute creation.
            return b''
        attr_payload = b''
    elif attr["type"] == 'string':
        # UTF-8 is a superset of ASCII: ASCII input encodes exactly as before,
        # non-ASCII input no longer raises UnicodeEncodeError.
        attr_payload = str(value).encode('utf-8') + b'\x00'
    elif attr["type"] == 'binary':
        if value is None:
            attr_payload = b''
        elif isinstance(value, bytes):
            attr_payload = value
        elif isinstance(value, str):
            if attr.display_hint:
                attr_payload = self._from_string(value, attr)
            else:
                attr_payload = bytes.fromhex(value)
        elif isinstance(value, dict) and attr.struct_name:
            attr_payload = self._encode_struct(attr.struct_name, value)
        elif isinstance(value, list) and attr.sub_type in NlAttr.type_formats:
            fmt = NlAttr.get_format(attr.sub_type)
            attr_payload = b''.join([fmt.pack(x) for x in value])
        else:
            raise Exception(f'Unknown type for binary attribute, value: {value}')
    elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
        scalar = self._get_scalar(attr, value)
        if attr.is_auto_scalar:
            # Use the 32-bit encoding only if the value fits its range.
            # bit_length() alone over-admits signed values such as 2**31,
            # which an s32 cannot hold.
            if attr["type"][0] == 's':
                fits32 = -2**31 <= scalar < 2**31
            else:
                fits32 = scalar < 2**32
            attr_type = attr["type"][0] + ('32' if fits32 else '64')
        else:
            attr_type = attr["type"]
        fmt = NlAttr.get_format(attr_type, attr.byte_order)
        attr_payload = fmt.pack(scalar)
    elif attr['type'] == "bitfield32":
        # Exact comparison; `in "bitfield32"` matched any substring of it.
        scalar_value = self._get_scalar(attr, value["value"])
        scalar_selector = self._get_scalar(attr, value["selector"])
        attr_payload = struct.pack("II", scalar_value, scalar_selector)
    elif attr['type'] == 'sub-message':
        msg_format, _ = self._resolve_selector(attr, search_attrs)
        attr_payload = b''
        if msg_format.fixed_header:
            attr_payload += self._encode_struct(msg_format.fixed_header, value)
        if msg_format.attr_set:
            if msg_format.attr_set in self.attr_sets:
                nl_type |= Netlink.NLA_F_NESTED
                # SpaceAttrs needs the attribute-set object (it tests
                # membership and reads .yaml['name']), not the set's name.
                sub_attrs = SpaceAttrs(self.attr_sets[msg_format.attr_set],
                                       value, search_attrs)
                for subname, subvalue in value.items():
                    attr_payload += self._add_attr(msg_format.attr_set,
                                                   subname, subvalue, sub_attrs)
            else:
                raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
    else:
        raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

    return self._add_attr_raw(nl_type, attr_payload)
625
626
def _add_attr_raw(self, nl_type, attr_payload):
    """Wrap *attr_payload* in an attribute header of type *nl_type*.

    The header length counts header plus payload; zero bytes are appended
    afterwards to reach a 4-byte boundary.
    """
    header = struct.pack('HH', len(attr_payload) + 4, nl_type)
    padding = bytes(-len(attr_payload) % 4)
    return header + attr_payload + padding
629
630
def _add_nest_attrs(self, value, sub_space, search_attrs):
    """Encode every key/value pair of dict *value* as attributes of
    *sub_space*, with *value* pushed as the innermost selector scope."""
    scope = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs)
    return b''.join(self._add_attr(sub_space, key, item, scope)
                    for key, item in value.items())
637
638
def _encode_indexed_array(self, vals, sub_space, search_attrs):
    """Encode list *vals* as nested attributes typed by list index (from 0).

    Each element is a dict encoded as attributes of *sub_space*.
    """
    chunks = []
    for index, item in enumerate(vals):
        nested = self._add_nest_attrs(item, sub_space, search_attrs)
        chunks.append(self._add_attr_raw(index | Netlink.NLA_F_NESTED, nested))
    return b''.join(chunks)
645
646
def _get_enum_or_unknown(self, enum, raw):
    """Return the name of *enum*'s entry with value *raw*.

    Unknown values give "Unknown(<raw>)" when process_unknown is set;
    otherwise the KeyError propagates.
    """
    try:
        entry = enum.entries_by_val[raw]
    except KeyError:
        if not self.process_unknown:
            raise
        return f"Unknown({raw})"
    return entry.name
655
656
def _decode_enum(self, raw, attr_spec):
    """Translate integer *raw* into enum entry name(s).

    Flag-style enums (enum type 'flags' or 'enum-as-flags') give a set with
    one name per set bit, looked up by bit position. Other enums give a
    single name.
    """
    enum = self.consts[attr_spec['enum']]
    if not (enum.type == 'flags' or attr_spec.get('enum-as-flags', False)):
        return self._get_enum_or_unknown(enum, raw)

    names = set()
    bit = 0
    remaining = raw
    while remaining:
        if remaining & 1:
            names.add(self._get_enum_or_unknown(enum, bit))
        remaining >>= 1
        bit += 1
    return names
669
670
def _decode_binary(self, attr, attr_spec):
    """Decode a 'binary' attribute.

    The payload is decoded as a struct (struct_name), or as a C array of
    sub_type scalars (each optionally enum- or display-hint-formatted), or
    else returned as raw bytes (optionally display-hint formatted).
    """
    if attr_spec.struct_name:
        return self._decode_struct(attr.raw, attr_spec.struct_name)

    if attr_spec.sub_type:
        items = attr.as_c_array(attr_spec.sub_type)
        if 'enum' in attr_spec:
            return [self._decode_enum(x, attr_spec) for x in items]
        if attr_spec.display_hint:
            return [self._formatted_string(x, attr_spec.display_hint) for x in items]
        return items

    data = attr.as_bin()
    if attr_spec.display_hint:
        return self._formatted_string(data, attr_spec.display_hint)
    return data
685
686
def _decode_array_attr(self, attr, attr_spec):
    """Decode an 'indexed-array' attribute into a list.

    Each nested entry is decoded per the spec's 'sub-type'. 'nest' entries
    become {entry type: decoded dict}; 'binary' and scalar entries become
    their (optionally formatted) values.
    """
    items = []
    pos = 0
    raw = attr.raw
    while pos < len(raw):
        entry = NlAttr(raw, pos)
        pos += entry.full_len
        sub_type = attr_spec["sub-type"]

        if sub_type == 'nest':
            nested = NlAttrs(entry.raw)
            items.append({entry.type: self._decode(nested, attr_spec['nested-attributes'])})
            continue

        if sub_type == 'binary':
            value = entry.as_bin()
            if attr_spec.display_hint:
                value = self._formatted_string(value, attr_spec.display_hint)
        elif sub_type in NlAttr.type_formats:
            value = entry.as_scalar(sub_type, attr_spec.byte_order)
            if 'enum' in attr_spec:
                value = self._decode_enum(value, attr_spec)
            elif attr_spec.display_hint:
                value = self._formatted_string(value, attr_spec.display_hint)
        else:
            raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
        items.append(value)
    return items
711
712
def _decode_nest_type_value(self, attr, attr_spec):
    """Decode a 'nest-type-value' attribute.

    Each name in the spec's 'type-value' list receives the type of one more
    level of nesting. The innermost payload is then decoded with the spec's
    'nested-attributes' set and merged in.
    """
    level = attr
    result = {}
    for key in attr_spec['type-value']:
        level = NlAttr(level.raw, 0)
        result[key] = level.type
    result.update(self._decode(NlAttrs(level.raw), attr_spec['nested-attributes']))
    return result
721
722
def _decode_unknown(self, attr):
    """Best-effort decode of an attribute without a spec.

    Non-nested attributes are returned as raw bytes; nested ones are decoded
    recursively without an attribute set.
    """
    if not attr.is_nest:
        return attr.as_bin()
    return self._decode(NlAttrs(attr.raw), None)
727
728
def _rsp_add(self, rsp, name, is_multi, decoded):
    """Store *decoded* under *name* in the response dict *rsp*.

    is_multi=True appends to a list and False overwrites. None means the
    multiplicity is unknown (attributes without a spec): the first value is
    stored as-is, and any repeat turns the entry into a list of all values.
    """
    if is_multi is None:
        if name in rsp:
            # Any repeat is multi-valued. An entry that is already a list
            # (third and later occurrence) must be appended to, not replaced.
            if type(rsp[name]) is not list:
                rsp[name] = [rsp[name]]
            is_multi = True
        else:
            is_multi = False

    if not is_multi:
        rsp[name] = decoded
    elif name in rsp:
        rsp[name].append(decoded)
    else:
        rsp[name] = [decoded]
742
743
def _resolve_selector(self, attr_spec, search_attrs):
    """Find the sub-message format for *attr_spec*.

    The value of the spec's selector attribute is looked up in
    *search_attrs* and mapped through the named sub-message spec.
    Returns (format, selector value). Raises when the sub-message spec or
    the format for that value does not exist.
    """
    name = attr_spec.sub_message
    if name not in self.sub_msgs:
        raise Exception(f"No sub-message spec named {name} for {attr_spec.name}")
    formats = self.sub_msgs[name].formats

    selector_value = search_attrs.lookup(attr_spec.selector)
    if selector_value not in formats:
        raise Exception(f"No message format for '{selector_value}' in sub-message spec '{name}'")

    return formats[selector_value], selector_value
756
757
def _decode_sub_msg(self, attr, attr_spec, search_attrs):
    """Decode a 'sub-message' attribute.

    The format is chosen by the selector (see _resolve_selector). Its fixed
    header, if any, is decoded first; the attributes that follow are then
    decoded with the format's attribute set.

    Raises:
        Exception: the format names an unknown attribute set.
    """
    msg_format, _ = self._resolve_selector(attr_spec, search_attrs)
    decoded = {}
    offset = 0
    if msg_format.fixed_header:
        decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header))
        offset = self._struct_size(msg_format.fixed_header)
    if msg_format.attr_set:
        if msg_format.attr_set in self.attr_sets:
            # Pass the enclosing scope, as nests do, so selectors of
            # sub-messages nested inside this one can also resolve against
            # attributes outside it.
            subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set,
                                   search_attrs)
            decoded.update(subdict)
        else:
            raise Exception(f"Unknown attribute-set '{msg_format.attr_set}' when decoding '{attr_spec.name}'")
    return decoded
771
772
def _decode(self, attrs, space, outer_attrs = None):
    """Decode parsed attributes *attrs* of attribute set *space* into a dict.

    Args:
        attrs: iterable of NlAttr.
        space: attribute-set name, or None/empty when no spec applies (every
            attribute is then treated as unknown).
        outer_attrs: enclosing SpaceAttrs scope for selector lookups.

    Returns:
        dict mapping attribute names to decoded values (lists for multi-attrs).

    Raises:
        Exception: unknown attribute or type while process_unknown is off;
            any decoding error is re-raised after naming the attribute on
            stderr.
    """
    rsp = dict()
    attr_space = None
    search_attrs = None
    if space:
        attr_space = self.attr_sets[space]
        search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)

    for attr in attrs:
        # Explicit None handling replaces the old reliance on catching
        # UnboundLocalError when no attribute space was given.
        attr_spec = None
        if attr_space is not None:
            try:
                attr_spec = attr_space.attrs_by_val[attr.type]
            except KeyError:
                pass
        if attr_spec is None:
            if not self.process_unknown:
                raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
            attr_name = f"UnknownAttr({attr.type})"
            self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
            continue

        try:
            if attr_spec["type"] == 'nest':
                subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
                decoded = subdict
            elif attr_spec["type"] == 'string':
                decoded = attr.as_strz()
            elif attr_spec["type"] == 'binary':
                decoded = self._decode_binary(attr, attr_spec)
            elif attr_spec["type"] == 'flag':
                decoded = True
            elif attr_spec.is_auto_scalar:
                decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
                if 'enum' in attr_spec:
                    decoded = self._decode_enum(decoded, attr_spec)
            elif attr_spec["type"] in NlAttr.type_formats:
                decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
                if 'enum' in attr_spec:
                    decoded = self._decode_enum(decoded, attr_spec)
                elif attr_spec.display_hint:
                    decoded = self._formatted_string(decoded, attr_spec.display_hint)
            elif attr_spec["type"] == 'indexed-array':
                decoded = self._decode_array_attr(attr, attr_spec)
            elif attr_spec["type"] == 'bitfield32':
                value, selector = struct.unpack("II", attr.raw)
                if 'enum' in attr_spec:
                    value = self._decode_enum(value, attr_spec)
                    selector = self._decode_enum(selector, attr_spec)
                decoded = {"value": value, "selector": selector}
            elif attr_spec["type"] == 'sub-message':
                decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
            elif attr_spec["type"] == 'nest-type-value':
                decoded = self._decode_nest_type_value(attr, attr_spec)
            else:
                if not self.process_unknown:
                    raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
                decoded = self._decode_unknown(attr)

            self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)
        except Exception:
            # Name the failing attribute on stderr (keeping stdout for
            # results), then let the original error propagate.
            print(f"Error decoding '{attr_spec.name}' from '{space}'", file=sys.stderr)
            raise

    return rsp
831
832
def _decode_extack_path(self, attrs, attr_set, offset, target, search_attrs):
    """Translate an extack byte offset into a dotted attribute path.

    Args:
        attrs: attributes of *attr_set* as they appear in the request.
        attr_set: attribute-set spec describing *attrs*.
        offset: byte offset of the first of *attrs* within the request.
        target: byte offset to locate ('bad-attr-offs').
        search_attrs: SpaceAttrs scope for resolving sub-message selectors.

    Returns:
        A path such as ".outer.inner", or None when no attribute header
        starts exactly at *target*.
    """
    for attr in attrs:
        try:
            spec = attr_set.attrs_by_val[attr.type]
        except KeyError:
            raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")

        if offset > target:
            break
        if offset == target:
            return '.' + spec.name

        attr_end = offset + attr.full_len
        if attr_end <= target:
            # target lies beyond this attribute; move on to the next one
            offset = attr_end
            continue

        # target lies inside this attribute: descend into its payload
        label = spec.name
        kind = spec['type']
        if kind == 'nest':
            inner_set = self.attr_sets[spec['nested-attributes']]
            search_attrs = SpaceAttrs(inner_set, search_attrs.lookup(spec['name']))
        elif kind == 'sub-message':
            msg_format, selector_value = self._resolve_selector(spec, search_attrs)
            if msg_format is None:
                raise Exception(f"Can't resolve sub-message of {spec['name']} for extack")
            inner_set = self.attr_sets[msg_format.attr_set]
            label += f"({selector_value})"
        else:
            raise Exception(f"Can't dive into {attr.type} ({spec['name']}) for extack")

        # nested attributes start after the 4-byte attribute header
        tail = self._decode_extack_path(NlAttrs(attr.raw), inner_set,
                                        offset + 4, target, search_attrs)
        return None if tail is None else '.' + label + tail

    return None
867
868
def _decode_extack(self, request, op, extack, vals):
    """Replace extack['bad-attr-offs'] with a readable 'bad-attr' path.

    *request* is the raw request that was sent for *op*, and *vals* the
    values it was built from (used to resolve sub-message selectors).
    *extack* is modified in place. Nothing changes when it has no offset
    or the offset cannot be mapped to an attribute.
    """
    if 'bad-attr-offs' not in extack:
        return

    parsed = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
    start = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
    scope = SpaceAttrs(op.attr_set, vals)
    path = self._decode_extack_path(parsed.raw_attrs, op.attr_set, start,
                                    extack['bad-attr-offs'], scope)
    if not path:
        return
    del extack['bad-attr-offs']
    extack['bad-attr'] = path
880
881
def _struct_size(self, name):
    """Return the encoded size in bytes of struct *name* (0 if *name* is falsy).

    Pad and plain binary members count their declared length, members
    holding a nested struct count that struct's size, and scalar members
    count their format size.
    """
    if not name:
        return 0
    total = 0
    for member in self.consts[name].members:
        if member.type not in ('pad', 'binary'):
            total += NlAttr.get_format(member.type, member.byte_order).size
        elif member.struct:
            total += self._struct_size(member.struct)
        else:
            total += member.len
    return total
897
898
def _decode_struct(self, data, name):
    """Decode struct *name* from the bytes *data* into a dict.

    Pad members are skipped. Binary members give a nested dict (when they
    hold a struct) or raw bytes of the declared length. Scalar members are
    unpacked with their byte order. Enum or display-hint formatting is then
    applied to every non-pad member value.
    """
    members = self.consts[name].members
    attrs = dict()
    offset = 0
    for m in members:
        value = None
        if m.type == 'pad':
            offset += m.len
        elif m.type == 'binary':
            if m.struct:
                # `size`/`fmt` instead of the builtins `len`/`format`.
                size = self._struct_size(m.struct)
                value = self._decode_struct(data[offset : offset + size],
                                            m.struct)
                offset += size
            else:
                value = data[offset : offset + m.len]
                offset += m.len
        else:
            fmt = NlAttr.get_format(m.type, m.byte_order)
            [ value ] = fmt.unpack_from(data, offset)
            offset += fmt.size
        if value is not None:
            if m.enum:
                value = self._decode_enum(value, m)
            elif m.display_hint:
                value = self._formatted_string(value, m.display_hint)
            attrs[m.name] = value
    return attrs
926
927
def _encode_struct(self, name, vals):
    """Encode struct *name* from the dict *vals*.

    Members found in *vals* are popped from it. The sub-message encoder in
    _add_attr relies on this and treats the leftover keys as attributes.
    Missing members encode as zeros. Binary members accept bytes (as
    returned by _decode_struct) or hex strings.
    """
    members = self.consts[name].members
    attr_payload = b''
    for m in members:
        value = vals.pop(m.name) if m.name in vals else None
        if m.type == 'pad':
            attr_payload += bytearray(m.len)
        elif m.type == 'binary':
            if m.struct:
                if value is None:
                    value = dict()
                attr_payload += self._encode_struct(m.struct, value)
            elif value is None:
                attr_payload += bytearray(m.len)
            elif isinstance(value, (bytes, bytearray)):
                # Decoded structs hold raw bytes here; accept them back so
                # decode -> encode round-trips instead of failing in fromhex.
                attr_payload += value
            else:
                attr_payload += bytes.fromhex(value)
        else:
            if value is None:
                value = 0
            fmt = NlAttr.get_format(m.type, m.byte_order)
            attr_payload += fmt.pack(value)
    return attr_payload
950
951
def _formatted_string(self, raw, display_hint):
    """Render *raw* (bytes or int) according to *display_hint*.

    'mac' gives colon-separated hex octets. 'hex' gives hex() for ints and
    space-separated hex bytes otherwise. 'ipv4'/'ipv6'/'ipv4-or-v6' give
    the address text and 'uuid' the UUID text. Any other hint returns *raw*
    unchanged.
    """
    if display_hint == 'mac':
        return ':'.join(f'{octet:02x}' for octet in raw)
    if display_hint == 'hex':
        return hex(raw) if isinstance(raw, int) else bytes.hex(raw, ' ')
    if display_hint in ('ipv4', 'ipv6', 'ipv4-or-v6'):
        return str(ipaddress.ip_address(raw))
    if display_hint == 'uuid':
        return str(uuid.UUID(bytes=raw))
    return raw
966
967
def _from_string(self, string, attr_spec):
    """
    Parse a user-supplied string into a raw attribute value.

    Only the IP ('ipv4', 'ipv6', 'ipv4-or-v6') and 'hex' display hints are
    supported. Binary attributes get bytes; all others get an int. Any other
    hint raises ``Exception``.
    """
    hint = attr_spec.display_hint
    if hint in ['ipv4', 'ipv6', 'ipv4-or-v6']:
        address = ipaddress.ip_address(string)
        return address.packed if attr_spec['type'] == 'binary' else int(address)
    if hint == 'hex':
        return bytes.fromhex(string) if attr_spec['type'] == 'binary' else int(string, 16)
    raise Exception(f"Display hint '{attr_spec.display_hint}' not implemented"
                    f" when parsing '{attr_spec['name']}'")
983
984
def handle_ntf(self, decoded):
    """
    Decode a notification and append it to ``self.async_msg_queue``.

    The queued dict has 'name' (the operation name) and 'msg' (decoded
    attributes plus any fixed-header members). It is preceded by 'raw' (the
    decoded message object) when ``self.include_raw`` is set.
    """
    op = self.rsp_by_value[decoded.cmd()]
    attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
    if op.fixed_header:
        attrs.update(self._decode_struct(decoded.raw, op.fixed_header))

    entry = {'raw': decoded} if self.include_raw else {}
    entry['name'] = op['name']
    entry['msg'] = attrs
    self.async_msg_queue.put(entry)
996
997
def check_ntf(self):
    """
    Drain the socket without blocking and queue every notification found.

    Returns once ``recv`` with MSG_DONTWAIT raises BlockingIOError. Error
    and done messages, and messages whose command is not in
    ``self.async_msg_ids``, are printed and skipped. Everything else goes to
    ``handle_ntf``.
    """
    while True:
        try:
            reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
        except BlockingIOError:
            return

        nms = NlMsgs(reply)
        self._recv_dbg_print(reply, nms)
        for nl_msg in nms:
            if nl_msg.error:
                print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
                print(nl_msg)
            elif nl_msg.done:
                print("Netlink done while checking for ntf!?")
            else:
                decoded = self.nlproto.decode(self, nl_msg, None)
                if decoded.cmd() in self.async_msg_ids:
                    self.handle_ntf(decoded)
                else:
                    print("Unexpected msg id while checking for ntf", decoded)
1021
1022
def poll_ntf(self, duration=None):
    """
    Generator yielding asynchronous notifications as they arrive.

    Messages already in ``self.async_msg_queue`` are yielded first. When the
    queue is empty, the generator waits for the socket to become readable
    and calls ``check_ntf()`` to pull new messages into the queue.

    Args:
        duration: stop after this many seconds; None waits forever.
    """
    # Monotonic time keeps the deadline immune to wall-clock jumps
    # (NTP adjustments, manual clock changes).
    deadline = None if duration is None else time.monotonic() + duration
    selector = selectors.DefaultSelector()
    try:
        selector.register(self.sock, selectors.EVENT_READ)
        while True:
            try:
                msg = self.async_msg_queue.get_nowait()
            except queue.Empty:
                if deadline is None:
                    timeout = None
                else:
                    timeout = deadline - time.monotonic()
                    if timeout <= 0:
                        return
                if selector.select(timeout):
                    self.check_ntf()
            else:
                yield msg
    finally:
        # Release the selector's OS resources even if the consumer stops
        # iterating early (generator close / garbage collection).
        selector.close()
1040
1041
def operation_do_attributes(self, name):
    """
    Return a copy of the 'attributes' entry of the 'do' request of operation
    ``name``, or None when ``find_operation`` finds nothing.

    NOTE(review): the previous docstring called this a dict; the container
    type comes from the spec and is not visible here — confirm.
    """
    op = self.find_operation(name)
    return op['do']['request']['attributes'].copy() if op else None
1051
1052
def _encode_message(self, op, vals, flags, req_seq):
    """
    Build the complete netlink message for one request.

    The header flags are NLM_F_REQUEST | NLM_F_ACK OR-ed with every entry of
    ``flags`` (which may be None). If the operation has a fixed header, it is
    encoded first. ``_encode_struct`` pops those members from ``vals``, so
    only the remaining entries are encoded as attributes.

    NOTE: ``vals`` is mutated — fixed-header keys are removed from it.
    """
    header_flags = functools.reduce(lambda acc, bit: acc | bit, flags or [],
                                    Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK)
    msg = self.nlproto.message(header_flags, op.req_value, 1, req_seq)
    if op.fixed_header:
        msg += self._encode_struct(op.fixed_header, vals)
    # Created after the fixed header is encoded, i.e. from ``vals`` with the
    # fixed-header members already removed.
    search_attrs = SpaceAttrs(op.attr_set, vals)
    for attr_name, attr_value in vals.items():
        msg += self._add_attr(op.attr_set.name, attr_name, attr_value, search_attrs)
    return _genl_msg_finalize(msg)
1065
1066
def _ops(self, ops):
    """
    Send a batch of requests in a single write and collect their replies.

    Each element of ``ops`` is a ``(method, vals, flags)`` triple. Requests
    get consecutive sequence numbers starting at a random value in
    [1024, 65535], so replies can be matched back to their request.

    One entry is appended to the returned list each time a message flagged
    ``done`` arrives for a pending request:

    * for a dump (NLM_F_DUMP among its flags), the list of decoded replies;
    * otherwise None for no reply, the reply itself for exactly one, or the
      list of replies for several.

    Messages that are not a response to a pending request are passed to
    ``handle_ntf`` if their command is a known async id, and printed
    otherwise.

    Raises NlError as soon as a received message has its error set.
    """
    pending = {}
    seq = random.randint(1024, 65535)
    payload = b''
    for method, vals, flags in ops:
        op = self.ops[method]
        req = self._encode_message(op, vals, flags, seq)
        pending[seq] = (op, vals, req, flags)
        payload += req
        seq += 1

    self.sock.send(payload, 0)

    results = []
    collected = []
    finished = False
    while not finished:
        reply = self.sock.recv(self._recv_size)
        nms = NlMsgs(reply)
        self._recv_dbg_print(reply, nms)
        for nl_msg in nms:
            known = nl_msg.nl_seq in pending
            if known:
                op, vals, req, req_flags = pending[nl_msg.nl_seq]
                if nl_msg.extack:
                    nl_msg.annotate_extack(op.attr_set)
                    self._decode_extack(req, op, nl_msg.extack, vals)
            else:
                op, req_flags = None, []

            if nl_msg.error:
                raise NlError(nl_msg)

            if nl_msg.done:
                if nl_msg.extack:
                    print("Netlink warning:")
                    print(nl_msg)

                if Netlink.NLM_F_DUMP in req_flags or len(collected) > 1:
                    results.append(collected)
                elif collected:
                    results.append(collected[0])
                else:
                    results.append(None)
                collected = []

                del pending[nl_msg.nl_seq]
                finished = not pending
                # Stop scanning this buffer once a request completes.
                break

            decoded = self.nlproto.decode(self, nl_msg, op)

            # Only a matching sequence number AND response command count as
            # a reply to one of our requests.
            if not known or decoded.cmd() != op.rsp_value:
                if decoded.cmd() in self.async_msg_ids:
                    self.handle_ntf(decoded)
                else:
                    print('Unexpected message: ' + repr(decoded))
                continue

            rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
            if op.fixed_header:
                rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
            collected.append(rsp_msg)

    return results
1134
1135
def _op(self, method, vals, flags=None, dump=False):
    """
    Run a single request through ``_ops`` and return its one reply entry.

    Args:
        method: operation name, looked up in ``self.ops`` by ``_ops``.
        vals: request values (consumed while encoding).
        flags: optional iterable of extra netlink header flags.
        dump: when True, NLM_F_DUMP is added to the request flags.
    """
    # Copy so that adding NLM_F_DUMP never mutates a list owned by the
    # caller (and so a tuple of flags also works).
    req_flags = list(flags or [])
    if dump:
        req_flags.append(Netlink.NLM_F_DUMP)

    return self._ops([(method, vals, req_flags)])[0]
1142
1143
def do(self, method, vals, flags=None):
    """Issue one (non-dump) request ``method`` and return its reply."""
    return self._op(method, vals, flags=flags)
1145
1146
def dump(self, method, vals):
    """Issue ``method`` as a dump request and return the list of replies."""
    return self._op(method, vals, flags=None, dump=True)
1148
1149
def do_multi(self, ops):
    """Send several ``(method, vals, flags)`` requests in one batch via ``_ops``."""
    replies = self._ops(ops)
    return replies
1151
1152