From 4a8fac28319c43ccf7d621a1654e288fe4aae65d Mon Sep 17 00:00:00 2001 From: Demi Marie Obenour Date: Thu, 1 Jun 2023 16:32:46 -0400 Subject: [PATCH] Overflow checks for sumof() Signed-off-by: Demi Marie Obenour --- src/c_client.py | 44 ++++++++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/src/c_client.py b/src/c_client.py index cb51227..f8cb2b3 100644 --- a/src/c_client.py +++ b/src/c_client.py @@ -881,7 +881,7 @@ def _c_serialize_helper_switch(context, self, complex_name, code_lines, temp_vars, space, prefix): count = 0 - switch_expr = _c_accessor_get_expr(self.expr, None) + switch_expr = _c_accessor_get_expr(self.expr, None, False) for b in self.bitcases: len_expr = len(b.type.expr) @@ -893,7 +893,7 @@ def _c_serialize_helper_switch(context, self, complex_name, compare_operator = '&' for n, expr in enumerate(b.type.expr): - bitcase_expr = _c_accessor_get_expr(expr, None) + bitcase_expr = _c_accessor_get_expr(expr, None, False) # only one in the if len_expr == 1: code_lines.append( @@ -1029,7 +1029,7 @@ def _c_serialize_helper_list_field(context, self, field, if expr.op == 'calculate_len': list_length = field.type.expr.lenfield_name else: - list_length = _c_accessor_get_expr(expr, field_mapping) + list_length = _c_accessor_get_expr(expr, field_mapping, False) # default: list with fixed size elements length = 'xcb_checked_mul(%s, sizeof(%s))' % (list_length, field.type.member.c_wiretype) @@ -1125,10 +1125,10 @@ def _c_serialize_helper_fields_fixed_size(context, self, field, # need to register a temporary variable for the expression in case we know its type if field.type.c_type is None: raise Exception("type for field '%s' (expression '%s') unkown" % - (field.field_name, _c_accessor_get_expr(field.type.expr))) + (field.field_name, _c_accessor_get_expr(field.type.expr, prefix, False))) temp_vars.append(' %s xcb_expr_%s = %s;' % (field.type.c_type, _cpp(field.field_name), - _c_accessor_get_expr(field.type.expr, prefix))) + _c_accessor_get_expr(field.type.expr, prefix, False))) value += "&xcb_expr_%s;" % _cpp(field.field_name) elif field.type.is_pad: @@ -1480,7 +1480,7 @@ def _c_serialize(context, self): field_mapping = _c_get_field_mapping_for_expr(self, self.length_expr, prefix) - _c(' return %s;', _c_accessor_get_expr(self.length_expr, field_mapping)) + _c(' return %s;', _c_accessor_get_expr(self.length_expr, field_mapping, context.endswith('checked'))) _c('}') _c_pre.redirect_end() return @@ -1732,7 +1732,7 @@ def _c_accessor_get_length(expr, field_mapping=None): else: return str(expr.nmemb) -def _c_accessor_get_expr(expr, field_mapping): +def _c_accessor_get_expr(expr, field_mapping, checked): ''' Figures out what C code is needed to get the length of a list field. The field_mapping parameter can be used to change the absolute name of a length field. @@ -1743,9 +1743,9 @@ def _c_accessor_get_expr(expr, field_mapping): lenexp = _c_accessor_get_length(expr, field_mapping) if expr.op == '~': # SAFE - return '(' + '~' + _c_accessor_get_expr(expr.rhs, field_mapping) + ')' + return '(' + '~' + _c_accessor_get_expr(expr.rhs, field_mapping, checked) + ')' elif expr.op == 'popcount': # SAFE - return 'xcb_popcount(' + _c_accessor_get_expr(expr.rhs, field_mapping) + ')' + return 'xcb_popcount(' + _c_accessor_get_expr(expr.rhs, field_mapping, checked) + ')' elif expr.op == 'enumref': # SAFE enum_name = expr.lenfield_type.name constant_name = expr.lenfield_name @@ -1756,7 +1756,7 @@ def _c_accessor_get_expr(expr, field_mapping): field = expr.lenfield list_name = field_mapping[field.c_field_name][0] c_length_func = "%s(%s)" % (field.c_length_name, list_name) - c_length_func = _c_accessor_get_expr(field.type.expr, field_mapping) + c_length_func = _c_accessor_get_expr(field.type.expr, field_mapping, checked) # create explicit code for computing the sum. # This works for all C-types which can be added to int64_t with += _c_pre.start() @@ -1801,11 +1801,15 @@ def _c_accessor_get_expr(expr, field_mapping): # cause pre-code of the subexpression be added right here _c_pre.end() # compute the subexpression - rhs_expr_str = _c_accessor_get_expr(expr.rhs, scoped_field_mapping) + rhs_expr_str = _c_accessor_get_expr(expr.rhs, scoped_field_mapping, checked) # resume with our code _c_pre.start() # output the summation expression - _c_pre.code("%s += %s;", sumvar, rhs_expr_str) + if checked: + _c_pre.code("if (__builtin_add_overflow(%s, %s, &%s))", sumvar, rhs_expr_str, sumvar) + _c_pre.code(" return -1;") + else: + _c_pre.code("%s += %s;", sumvar, rhs_expr_str) _c_pre.code("%s++;", listvar) _c_pre.pop_indent() @@ -1816,9 +1820,9 @@ def _c_accessor_get_expr(expr, field_mapping): elif expr.op == 'listelement-ref': return '(*xcb_listelement)' elif expr.op != None and expr.op != 'calculate_len': - return ('(' + _c_accessor_get_expr(expr.lhs, field_mapping) + + return ('(' + _c_accessor_get_expr(expr.lhs, field_mapping, checked) + ' ' + expr.op + ' ' + - _c_accessor_get_expr(expr.rhs, field_mapping) + ')') + _c_accessor_get_expr(expr.rhs, field_mapping, checked) + ')') elif expr.bitfield: return 'xcb_popcount(' + lenexp + ')' else: @@ -2030,7 +2034,7 @@ def _c_accessors_list(self, field): (field.c_field_name) ); else: - return _c_accessor_get_expr(field.type.expr, fields) + return _c_accessor_get_expr(field.type.expr, fields, False) _c(' return %s;', get_length()) _c('}') @@ -2435,7 +2439,7 @@ def _c_request_helper(self, name, void, regular, aux=False, reply_fds=False): if not field.type.is_list: num_fds_fixed += 1 else: - num_fds_expr.append(_c_accessor_get_expr(field.type.expr, None)) + num_fds_expr.append(_c_accessor_get_expr(field.type.expr, None, True)) if list_with_var_size_elems or len(num_fds_expr) > 0: _c(' unsigned int i;') @@ -2455,7 +2459,7 @@ def _c_request_helper(self, name, void, regular, aux=False, reply_fds=False): for field in wire_fields: if field.type.fixed_size(): if field.type.is_expr: - _c(' xcb_out.%s = %s;', field.c_field_name, _c_accessor_get_expr(field.type.expr, None)) + _c(' xcb_out.%s = %s;', field.c_field_name, _c_accessor_get_expr(field.type.expr, None, True)) elif field.type.is_pad: if field.type.nmemb == 1: _c(' xcb_out.%s = 0;', field.c_field_name) @@ -2496,12 +2500,12 @@ def _c_request_helper(self, name, void, regular, aux=False, reply_fds=False): if field.type.expr.op == 'calculate_len': lenfield = field.type.expr.lenfield_name else: - lenfield = _c_accessor_get_expr(field.type.expr, None) + lenfield = _c_accessor_get_expr(field.type.expr, None, False) _c(' xcb_parts[%d].iov_len = %s * sizeof(%s);', count, lenfield, field.type.member.c_wiretype) else: - list_length = _c_accessor_get_expr(field.type.expr, None) + list_length = _c_accessor_get_expr(field.type.expr, None, False) length = '' _c(" xcb_parts[%d].iov_len = 0;" % count) @@ -2559,7 +2563,7 @@ def _c_request_helper(self, name, void, regular, aux=False, reply_fds=False): if not field.type.is_list: _c(' fds[fd_index++] = %s;', field.c_field_name) else: - _c(' for (i = 0; i < %s; i++)', _c_accessor_get_expr(field.type.expr, None)) + _c(' for (i = 0; i < %s; i++)', _c_accessor_get_expr(field.type.expr, None, False)) _c(' fds[fd_index++] = %s[i];', field.c_field_name) if not num_fds: