Overflow checks for sumof()

Signed-off-by: Demi Marie Obenour <demiobenour@gmail.com>
This commit is contained in:
Demi Marie Obenour 2023-06-01 16:32:46 -04:00
parent 9f0edb0b54
commit 4a8fac2831

View File

@ -881,7 +881,7 @@ def _c_serialize_helper_switch(context, self, complex_name,
code_lines, temp_vars,
space, prefix):
count = 0
switch_expr = _c_accessor_get_expr(self.expr, None)
switch_expr = _c_accessor_get_expr(self.expr, None, False)
for b in self.bitcases:
len_expr = len(b.type.expr)
@ -893,7 +893,7 @@ def _c_serialize_helper_switch(context, self, complex_name,
compare_operator = '&'
for n, expr in enumerate(b.type.expr):
bitcase_expr = _c_accessor_get_expr(expr, None)
bitcase_expr = _c_accessor_get_expr(expr, None, False)
# only one <enumref> in the <bitcase>
if len_expr == 1:
code_lines.append(
@ -1029,7 +1029,7 @@ def _c_serialize_helper_list_field(context, self, field,
if expr.op == 'calculate_len':
list_length = field.type.expr.lenfield_name
else:
list_length = _c_accessor_get_expr(expr, field_mapping)
list_length = _c_accessor_get_expr(expr, field_mapping, False)
# default: list with fixed size elements
length = 'xcb_checked_mul(%s, sizeof(%s))' % (list_length, field.type.member.c_wiretype)
@ -1125,10 +1125,10 @@ def _c_serialize_helper_fields_fixed_size(context, self, field,
# need to register a temporary variable for the expression in case we know its type
if field.type.c_type is None:
raise Exception("type for field '%s' (expression '%s') unkown" %
(field.field_name, _c_accessor_get_expr(field.type.expr)))
(field.field_name, _c_accessor_get_expr(field.type.expr, prefix, False)))
temp_vars.append(' %s xcb_expr_%s = %s;' % (field.type.c_type, _cpp(field.field_name),
_c_accessor_get_expr(field.type.expr, prefix)))
_c_accessor_get_expr(field.type.expr, prefix, False)))
value += "&xcb_expr_%s;" % _cpp(field.field_name)
elif field.type.is_pad:
@ -1480,7 +1480,7 @@ def _c_serialize(context, self):
field_mapping = _c_get_field_mapping_for_expr(self, self.length_expr, prefix)
_c(' return %s;', _c_accessor_get_expr(self.length_expr, field_mapping))
_c(' return %s;', _c_accessor_get_expr(self.length_expr, field_mapping, context.endswith('checked')))
_c('}')
_c_pre.redirect_end()
return
@ -1732,7 +1732,7 @@ def _c_accessor_get_length(expr, field_mapping=None):
else:
return str(expr.nmemb)
def _c_accessor_get_expr(expr, field_mapping):
def _c_accessor_get_expr(expr, field_mapping, checked):
'''
Figures out what C code is needed to get the length of a list field.
The field_mapping parameter can be used to change the absolute name of a length field.
@ -1743,9 +1743,9 @@ def _c_accessor_get_expr(expr, field_mapping):
lenexp = _c_accessor_get_length(expr, field_mapping)
if expr.op == '~': # SAFE
return '(' + '~' + _c_accessor_get_expr(expr.rhs, field_mapping) + ')'
return '(' + '~' + _c_accessor_get_expr(expr.rhs, field_mapping, checked) + ')'
elif expr.op == 'popcount': # SAFE
return 'xcb_popcount(' + _c_accessor_get_expr(expr.rhs, field_mapping) + ')'
return 'xcb_popcount(' + _c_accessor_get_expr(expr.rhs, field_mapping, checked) + ')'
elif expr.op == 'enumref': # SAFE
enum_name = expr.lenfield_type.name
constant_name = expr.lenfield_name
@ -1756,7 +1756,7 @@ def _c_accessor_get_expr(expr, field_mapping):
field = expr.lenfield
list_name = field_mapping[field.c_field_name][0]
c_length_func = "%s(%s)" % (field.c_length_name, list_name)
c_length_func = _c_accessor_get_expr(field.type.expr, field_mapping)
c_length_func = _c_accessor_get_expr(field.type.expr, field_mapping, checked)
# create explicit code for computing the sum.
# This works for all C-types which can be added to int64_t with +=
_c_pre.start()
@ -1801,10 +1801,14 @@ def _c_accessor_get_expr(expr, field_mapping):
# cause pre-code of the subexpression be added right here
_c_pre.end()
# compute the subexpression
rhs_expr_str = _c_accessor_get_expr(expr.rhs, scoped_field_mapping)
rhs_expr_str = _c_accessor_get_expr(expr.rhs, scoped_field_mapping, checked)
# resume with our code
_c_pre.start()
# output the summation expression
if checked:
_c_pre.code("if (__builtin_add_overflow(%s, %s, &%s))", sumvar, rhs_expr_str, sumvar)
_c_pre.code(" return -1;")
else:
_c_pre.code("%s += %s;", sumvar, rhs_expr_str)
_c_pre.code("%s++;", listvar)
@ -1816,9 +1820,9 @@ def _c_accessor_get_expr(expr, field_mapping):
elif expr.op == 'listelement-ref':
return '(*xcb_listelement)'
elif expr.op != None and expr.op != 'calculate_len':
return ('(' + _c_accessor_get_expr(expr.lhs, field_mapping) +
return ('(' + _c_accessor_get_expr(expr.lhs, field_mapping, checked) +
' ' + expr.op + ' ' +
_c_accessor_get_expr(expr.rhs, field_mapping) + ')')
_c_accessor_get_expr(expr.rhs, field_mapping, checked) + ')')
elif expr.bitfield:
return 'xcb_popcount(' + lenexp + ')'
else:
@ -2030,7 +2034,7 @@ def _c_accessors_list(self, field):
(field.c_field_name)
);
else:
return _c_accessor_get_expr(field.type.expr, fields)
return _c_accessor_get_expr(field.type.expr, fields, False)
_c(' return %s;', get_length())
_c('}')
@ -2435,7 +2439,7 @@ def _c_request_helper(self, name, void, regular, aux=False, reply_fds=False):
if not field.type.is_list:
num_fds_fixed += 1
else:
num_fds_expr.append(_c_accessor_get_expr(field.type.expr, None))
num_fds_expr.append(_c_accessor_get_expr(field.type.expr, None, True))
if list_with_var_size_elems or len(num_fds_expr) > 0:
_c(' unsigned int i;')
@ -2455,7 +2459,7 @@ def _c_request_helper(self, name, void, regular, aux=False, reply_fds=False):
for field in wire_fields:
if field.type.fixed_size():
if field.type.is_expr:
_c(' xcb_out.%s = %s;', field.c_field_name, _c_accessor_get_expr(field.type.expr, None))
_c(' xcb_out.%s = %s;', field.c_field_name, _c_accessor_get_expr(field.type.expr, None, True))
elif field.type.is_pad:
if field.type.nmemb == 1:
_c(' xcb_out.%s = 0;', field.c_field_name)
@ -2496,12 +2500,12 @@ def _c_request_helper(self, name, void, regular, aux=False, reply_fds=False):
if field.type.expr.op == 'calculate_len':
lenfield = field.type.expr.lenfield_name
else:
lenfield = _c_accessor_get_expr(field.type.expr, None)
lenfield = _c_accessor_get_expr(field.type.expr, None, False)
_c(' xcb_parts[%d].iov_len = %s * sizeof(%s);', count, lenfield,
field.type.member.c_wiretype)
else:
list_length = _c_accessor_get_expr(field.type.expr, None)
list_length = _c_accessor_get_expr(field.type.expr, None, False)
length = ''
_c(" xcb_parts[%d].iov_len = 0;" % count)
@ -2559,7 +2563,7 @@ def _c_request_helper(self, name, void, regular, aux=False, reply_fds=False):
if not field.type.is_list:
_c(' fds[fd_index++] = %s;', field.c_field_name)
else:
_c(' for (i = 0; i < %s; i++)', _c_accessor_get_expr(field.type.expr, None))
_c(' for (i = 0; i < %s; i++)', _c_accessor_get_expr(field.type.expr, None, False))
_c(' fds[fd_index++] = %s[i];', field.c_field_name)
if not num_fds: