Overflow checks for sumof()

Signed-off-by: Demi Marie Obenour <demiobenour@gmail.com>
This commit is contained in:
Demi Marie Obenour 2023-06-01 16:32:46 -04:00
parent 9f0edb0b54
commit 4a8fac2831

View File

@ -881,7 +881,7 @@ def _c_serialize_helper_switch(context, self, complex_name,
code_lines, temp_vars, code_lines, temp_vars,
space, prefix): space, prefix):
count = 0 count = 0
switch_expr = _c_accessor_get_expr(self.expr, None) switch_expr = _c_accessor_get_expr(self.expr, None, False)
for b in self.bitcases: for b in self.bitcases:
len_expr = len(b.type.expr) len_expr = len(b.type.expr)
@ -893,7 +893,7 @@ def _c_serialize_helper_switch(context, self, complex_name,
compare_operator = '&' compare_operator = '&'
for n, expr in enumerate(b.type.expr): for n, expr in enumerate(b.type.expr):
bitcase_expr = _c_accessor_get_expr(expr, None) bitcase_expr = _c_accessor_get_expr(expr, None, False)
# only one <enumref> in the <bitcase> # only one <enumref> in the <bitcase>
if len_expr == 1: if len_expr == 1:
code_lines.append( code_lines.append(
@ -1029,7 +1029,7 @@ def _c_serialize_helper_list_field(context, self, field,
if expr.op == 'calculate_len': if expr.op == 'calculate_len':
list_length = field.type.expr.lenfield_name list_length = field.type.expr.lenfield_name
else: else:
list_length = _c_accessor_get_expr(expr, field_mapping) list_length = _c_accessor_get_expr(expr, field_mapping, False)
# default: list with fixed size elements # default: list with fixed size elements
length = 'xcb_checked_mul(%s, sizeof(%s))' % (list_length, field.type.member.c_wiretype) length = 'xcb_checked_mul(%s, sizeof(%s))' % (list_length, field.type.member.c_wiretype)
@ -1125,10 +1125,10 @@ def _c_serialize_helper_fields_fixed_size(context, self, field,
# need to register a temporary variable for the expression in case we know its type # need to register a temporary variable for the expression in case we know its type
if field.type.c_type is None: if field.type.c_type is None:
raise Exception("type for field '%s' (expression '%s') unkown" % raise Exception("type for field '%s' (expression '%s') unkown" %
(field.field_name, _c_accessor_get_expr(field.type.expr))) (field.field_name, _c_accessor_get_expr(field.type.expr, prefix, False)))
temp_vars.append(' %s xcb_expr_%s = %s;' % (field.type.c_type, _cpp(field.field_name), temp_vars.append(' %s xcb_expr_%s = %s;' % (field.type.c_type, _cpp(field.field_name),
_c_accessor_get_expr(field.type.expr, prefix))) _c_accessor_get_expr(field.type.expr, prefix, False)))
value += "&xcb_expr_%s;" % _cpp(field.field_name) value += "&xcb_expr_%s;" % _cpp(field.field_name)
elif field.type.is_pad: elif field.type.is_pad:
@ -1480,7 +1480,7 @@ def _c_serialize(context, self):
field_mapping = _c_get_field_mapping_for_expr(self, self.length_expr, prefix) field_mapping = _c_get_field_mapping_for_expr(self, self.length_expr, prefix)
_c(' return %s;', _c_accessor_get_expr(self.length_expr, field_mapping)) _c(' return %s;', _c_accessor_get_expr(self.length_expr, field_mapping, context.endswith('checked')))
_c('}') _c('}')
_c_pre.redirect_end() _c_pre.redirect_end()
return return
@ -1732,7 +1732,7 @@ def _c_accessor_get_length(expr, field_mapping=None):
else: else:
return str(expr.nmemb) return str(expr.nmemb)
def _c_accessor_get_expr(expr, field_mapping): def _c_accessor_get_expr(expr, field_mapping, checked):
''' '''
Figures out what C code is needed to get the length of a list field. Figures out what C code is needed to get the length of a list field.
The field_mapping parameter can be used to change the absolute name of a length field. The field_mapping parameter can be used to change the absolute name of a length field.
@ -1743,9 +1743,9 @@ def _c_accessor_get_expr(expr, field_mapping):
lenexp = _c_accessor_get_length(expr, field_mapping) lenexp = _c_accessor_get_length(expr, field_mapping)
if expr.op == '~': # SAFE if expr.op == '~': # SAFE
return '(' + '~' + _c_accessor_get_expr(expr.rhs, field_mapping) + ')' return '(' + '~' + _c_accessor_get_expr(expr.rhs, field_mapping, checked) + ')'
elif expr.op == 'popcount': # SAFE elif expr.op == 'popcount': # SAFE
return 'xcb_popcount(' + _c_accessor_get_expr(expr.rhs, field_mapping) + ')' return 'xcb_popcount(' + _c_accessor_get_expr(expr.rhs, field_mapping, checked) + ')'
elif expr.op == 'enumref': # SAFE elif expr.op == 'enumref': # SAFE
enum_name = expr.lenfield_type.name enum_name = expr.lenfield_type.name
constant_name = expr.lenfield_name constant_name = expr.lenfield_name
@ -1756,7 +1756,7 @@ def _c_accessor_get_expr(expr, field_mapping):
field = expr.lenfield field = expr.lenfield
list_name = field_mapping[field.c_field_name][0] list_name = field_mapping[field.c_field_name][0]
c_length_func = "%s(%s)" % (field.c_length_name, list_name) c_length_func = "%s(%s)" % (field.c_length_name, list_name)
c_length_func = _c_accessor_get_expr(field.type.expr, field_mapping) c_length_func = _c_accessor_get_expr(field.type.expr, field_mapping, checked)
# create explicit code for computing the sum. # create explicit code for computing the sum.
# This works for all C-types which can be added to int64_t with += # This works for all C-types which can be added to int64_t with +=
_c_pre.start() _c_pre.start()
@ -1801,11 +1801,15 @@ def _c_accessor_get_expr(expr, field_mapping):
# cause pre-code of the subexpression be added right here # cause pre-code of the subexpression be added right here
_c_pre.end() _c_pre.end()
# compute the subexpression # compute the subexpression
rhs_expr_str = _c_accessor_get_expr(expr.rhs, scoped_field_mapping) rhs_expr_str = _c_accessor_get_expr(expr.rhs, scoped_field_mapping, checked)
# resume with our code # resume with our code
_c_pre.start() _c_pre.start()
# output the summation expression # output the summation expression
_c_pre.code("%s += %s;", sumvar, rhs_expr_str) if checked:
_c_pre.code("if (__builtin_add_overflow(%s, %s, &%s))", sumvar, rhs_expr_str, sumvar)
_c_pre.code(" return -1;")
else:
_c_pre.code("%s += %s;", sumvar, rhs_expr_str)
_c_pre.code("%s++;", listvar) _c_pre.code("%s++;", listvar)
_c_pre.pop_indent() _c_pre.pop_indent()
@ -1816,9 +1820,9 @@ def _c_accessor_get_expr(expr, field_mapping):
elif expr.op == 'listelement-ref': elif expr.op == 'listelement-ref':
return '(*xcb_listelement)' return '(*xcb_listelement)'
elif expr.op != None and expr.op != 'calculate_len': elif expr.op != None and expr.op != 'calculate_len':
return ('(' + _c_accessor_get_expr(expr.lhs, field_mapping) + return ('(' + _c_accessor_get_expr(expr.lhs, field_mapping, checked) +
' ' + expr.op + ' ' + ' ' + expr.op + ' ' +
_c_accessor_get_expr(expr.rhs, field_mapping) + ')') _c_accessor_get_expr(expr.rhs, field_mapping, checked) + ')')
elif expr.bitfield: elif expr.bitfield:
return 'xcb_popcount(' + lenexp + ')' return 'xcb_popcount(' + lenexp + ')'
else: else:
@ -2030,7 +2034,7 @@ def _c_accessors_list(self, field):
(field.c_field_name) (field.c_field_name)
); );
else: else:
return _c_accessor_get_expr(field.type.expr, fields) return _c_accessor_get_expr(field.type.expr, fields, False)
_c(' return %s;', get_length()) _c(' return %s;', get_length())
_c('}') _c('}')
@ -2435,7 +2439,7 @@ def _c_request_helper(self, name, void, regular, aux=False, reply_fds=False):
if not field.type.is_list: if not field.type.is_list:
num_fds_fixed += 1 num_fds_fixed += 1
else: else:
num_fds_expr.append(_c_accessor_get_expr(field.type.expr, None)) num_fds_expr.append(_c_accessor_get_expr(field.type.expr, None, True))
if list_with_var_size_elems or len(num_fds_expr) > 0: if list_with_var_size_elems or len(num_fds_expr) > 0:
_c(' unsigned int i;') _c(' unsigned int i;')
@ -2455,7 +2459,7 @@ def _c_request_helper(self, name, void, regular, aux=False, reply_fds=False):
for field in wire_fields: for field in wire_fields:
if field.type.fixed_size(): if field.type.fixed_size():
if field.type.is_expr: if field.type.is_expr:
_c(' xcb_out.%s = %s;', field.c_field_name, _c_accessor_get_expr(field.type.expr, None)) _c(' xcb_out.%s = %s;', field.c_field_name, _c_accessor_get_expr(field.type.expr, None, True))
elif field.type.is_pad: elif field.type.is_pad:
if field.type.nmemb == 1: if field.type.nmemb == 1:
_c(' xcb_out.%s = 0;', field.c_field_name) _c(' xcb_out.%s = 0;', field.c_field_name)
@ -2496,12 +2500,12 @@ def _c_request_helper(self, name, void, regular, aux=False, reply_fds=False):
if field.type.expr.op == 'calculate_len': if field.type.expr.op == 'calculate_len':
lenfield = field.type.expr.lenfield_name lenfield = field.type.expr.lenfield_name
else: else:
lenfield = _c_accessor_get_expr(field.type.expr, None) lenfield = _c_accessor_get_expr(field.type.expr, None, False)
_c(' xcb_parts[%d].iov_len = %s * sizeof(%s);', count, lenfield, _c(' xcb_parts[%d].iov_len = %s * sizeof(%s);', count, lenfield,
field.type.member.c_wiretype) field.type.member.c_wiretype)
else: else:
list_length = _c_accessor_get_expr(field.type.expr, None) list_length = _c_accessor_get_expr(field.type.expr, None, False)
length = '' length = ''
_c(" xcb_parts[%d].iov_len = 0;" % count) _c(" xcb_parts[%d].iov_len = 0;" % count)
@ -2559,7 +2563,7 @@ def _c_request_helper(self, name, void, regular, aux=False, reply_fds=False):
if not field.type.is_list: if not field.type.is_list:
_c(' fds[fd_index++] = %s;', field.c_field_name) _c(' fds[fd_index++] = %s;', field.c_field_name)
else: else:
_c(' for (i = 0; i < %s; i++)', _c_accessor_get_expr(field.type.expr, None)) _c(' for (i = 0; i < %s; i++)', _c_accessor_get_expr(field.type.expr, None, False))
_c(' fds[fd_index++] = %s[i];', field.c_field_name) _c(' fds[fd_index++] = %s[i];', field.c_field_name)
if not num_fds: if not num_fds: