Add fold expressions for argument packs

This commit is contained in:
drmortalwombat 2023-09-17 20:34:42 +02:00
parent 38274fb4f7
commit f211eef837
7 changed files with 323 additions and 12 deletions

View File

@ -322,6 +322,12 @@ ostream & ostream::operator<<(bool val)
return *this;
}
ostream & ostream::operator<<(char val)
{
bput(val);
return *this;
}
ostream & ostream::operator<<(int val)
{
if (val < 0)

View File

@ -91,6 +91,7 @@ public:
ostream & write(const char * s, int n);
ostream & operator<<(bool val);
ostream & operator<<(char val);
ostream & operator<<(int val);
ostream & operator<<(unsigned val);
ostream & operator<<(long val);

View File

@ -923,15 +923,30 @@ Declaration* Declaration::BuildConstPointer(const Location& loc)
return pdec;
}
Declaration* Declaration::BuildConstReference(const Location& loc)
Declaration* Declaration::BuildConstReference(const Location& loc, DecType type)
{
Declaration* pdec = new Declaration(loc, DT_TYPE_REFERENCE);
Declaration* pdec = new Declaration(loc, type);
pdec->mBase = this;
pdec->mFlags = DTF_DEFINED | DTF_CONST;
pdec->mSize = 2;
return pdec;
}
Declaration* Declaration::BuildArrayPointer(void)
{
if (mType == DT_TYPE_ARRAY)
{
Declaration* pdec = new Declaration(mLocation, DT_TYPE_POINTER);
pdec->mBase = mBase;
pdec->mFlags = DTF_DEFINED;
pdec->mSize = 2;
pdec->mStride = mStride;
return pdec;
}
else
return this;
}
Declaration* Declaration::BuildPointer(const Location& loc)
{
Declaration* pdec = new Declaration(loc, DT_TYPE_POINTER);
@ -941,9 +956,9 @@ Declaration* Declaration::BuildPointer(const Location& loc)
return pdec;
}
Declaration* Declaration::BuildReference(const Location& loc)
Declaration* Declaration::BuildReference(const Location& loc, DecType type)
{
Declaration* pdec = new Declaration(loc, DT_TYPE_REFERENCE);
Declaration* pdec = new Declaration(loc, type);
pdec->mBase = this;
pdec->mFlags = DTF_DEFINED;
pdec->mSize = 2;
@ -987,7 +1002,7 @@ Declaration* Declaration::DeduceAuto(Declaration * dec)
dec = dec->ToMutableType();
if (IsReference())
dec = dec->BuildReference(mLocation);
dec = dec->BuildReference(mLocation, mType);
return dec;
}
@ -2167,6 +2182,12 @@ bool Declaration::IsReference(void) const
return mType == DT_TYPE_REFERENCE || mType == DT_TYPE_RVALUEREF;
}
bool Declaration::IsIndexed(void) const
{
return mType == DT_TYPE_ARRAY || mType == DT_TYPE_POINTER;
}
bool Declaration::IsSimpleType(void) const
{
return mType == DT_TYPE_INTEGER || mType == DT_TYPE_BOOL || mType == DT_TYPE_FLOAT || mType == DT_TYPE_ENUM || mType == DT_TYPE_POINTER;

View File

@ -299,6 +299,7 @@ public:
bool IsNumericType(void) const;
bool IsSimpleType(void) const;
bool IsReference(void) const;
bool IsIndexed(void) const;
void SetDefined(void);
@ -311,12 +312,13 @@ public:
Declaration* Last(void);
Declaration* BuildPointer(const Location& loc);
Declaration* BuildReference(const Location& loc);
Declaration* BuildReference(const Location& loc, DecType type = DT_TYPE_REFERENCE);
Declaration* BuildConstPointer(const Location& loc);
Declaration* BuildConstReference(const Location& loc);
Declaration* BuildConstReference(const Location& loc, DecType type = DT_TYPE_REFERENCE);
Declaration* BuildRValueRef(const Location& loc);
Declaration* BuildConstRValueRef(const Location& loc);
Declaration* NonRefBase(void);
Declaration* BuildArrayPointer(void);
Declaration* DeduceAuto(Declaration* dec);
bool IsAuto(void) const;

View File

@ -86,6 +86,7 @@ enum ErrorID
EERR_INVALID_BITFIELD,
EERR_INVALID_CAPTURE,
EERR_INVALID_PACK_USAGE,
EERR_INVALID_FOLD_EXPRESSION,
EERR_INVALID_CONSTEXPR,
EERR_DOUBLE_FREE,

View File

@ -1203,9 +1203,9 @@ Declaration * Parser::ParseFunctionDeclaration(Declaration* bdec)
if (adec->mBase->IsReference())
{
if (adec->mBase->mBase->mFlags & DTF_CONST)
apdec->mBase = atdec->ToConstType()->BuildReference(adec->mLocation);
apdec->mBase = atdec->ToConstType()->BuildReference(adec->mLocation, adec->mBase->mType);
else
apdec->mBase = atdec->BuildReference(adec->mLocation);
apdec->mBase = atdec->BuildReference(adec->mLocation, adec->mBase->mType);
}
else
apdec->mBase = atdec;
@ -5009,7 +5009,7 @@ Expression* Parser::ParseSimpleExpression(bool lhs)
case TK_CHARACTER:
dec = new Declaration(mScanner->mLocation, DT_CONST_INTEGER);
dec->mInteger = mCharMap[(unsigned char)mScanner->mTokenInteger];
dec->mBase = TheUnsignedIntTypeDeclaration;
dec->mBase = TheUnsignedCharTypeDeclaration;
exp = new Expression(mScanner->mLocation, EX_CONSTANT);
exp->mDecValue = dec;
exp->mDecType = dec->mBase;
@ -5382,7 +5382,7 @@ Expression* Parser::ParseSimpleExpression(bool lhs)
case TK_OPEN_PARENTHESIS:
mScanner->NextToken();
exp = ParseExpression(true);
exp = ParseListExpression(true);
if (mScanner->mToken == TK_CLOSE_PARENTHESIS)
mScanner->NextToken();
else
@ -6099,6 +6099,9 @@ Expression * Parser::ResolveOverloadCall(Expression* exp, Expression* exp2)
{
if (exp->mType == EX_CALL && exp->mLeft->mDecValue)
{
if (exp->mRight && FindPackExpression(exp->mRight))
return exp;
Declaration* fdec = exp->mLeft->mDecValue;
Declaration* fdec2 = exp2 ? exp2->mLeft->mDecValue : nullptr;
@ -6928,6 +6931,10 @@ Expression* Parser::ParseMulExpression(bool lhs)
nexp->mToken = mScanner->mToken;
nexp->mLeft = exp;
mScanner->NextToken();
if (ConsumeTokenIf(TK_ELLIPSIS))
return ParseBinaryFoldExpression(nexp);
nexp->mRight = ParsePrefixExpression(false);
if (nexp->mLeft->mDecType->mType == DT_TYPE_FLOAT || nexp->mRight->mDecType->mType == DT_TYPE_FLOAT)
@ -6943,6 +6950,192 @@ Expression* Parser::ParseMulExpression(bool lhs)
return exp;
}
Expression* Parser::FindPackExpression(Expression* exp)
{
if (exp->mType == EX_PACK)
return exp;
if (exp->mLeft)
{
Expression* nexp = FindPackExpression(exp->mLeft);
if (nexp)
return nexp;
}
if (exp->mRight)
{
Expression* nexp = FindPackExpression(exp->mRight);
if (nexp)
return nexp;
}
return nullptr;
}
Expression* Parser::ExpandPackExpression(Expression* exp, Expression* pack, Expression* item)
{
if (exp == pack)
return item;
Expression* left = exp->mLeft ? ExpandPackExpression(exp->mLeft, pack, item) : nullptr;
Expression* right = exp->mRight ? ExpandPackExpression(exp->mRight, pack, item) : nullptr;
if (left == exp->mLeft && right == exp->mRight)
return exp;
Expression* nexp = new Expression(exp->mLocation, exp->mType);
nexp->mToken = exp->mToken;
nexp->mLeft = left;
nexp->mRight = right;
nexp->mDecValue = exp->mDecValue;
nexp->mDecType = exp->mDecType;
if (nexp->mType == EX_CALL)
nexp = ResolveOverloadCall(nexp);
return nexp;
}
Declaration* Parser::OperatorResultType(Expression* exp)
{
if (exp->mType == EX_BINARY)
{
Declaration* lt = exp->mLeft->mDecType->NonRefBase();
Declaration* rt = exp->mRight->mDecType->NonRefBase();
switch (exp->mToken)
{
case TK_MUL:
case TK_DIV:
case TK_MOD:
if (lt->mType == DT_TYPE_FLOAT || rt->mType == DT_TYPE_FLOAT)
return TheFloatTypeDeclaration;
else
return lt;
case TK_ADD:
case TK_SUB:
if (lt->IsIndexed() && rt->IsIntegerType())
return lt->BuildArrayPointer();
else if (rt->IsIndexed() && lt->IsIntegerType())
return rt->BuildArrayPointer();
else if (lt->mType == DT_TYPE_FLOAT || rt->mType == DT_TYPE_FLOAT)
return TheFloatTypeDeclaration;
else if (lt->IsIndexed() && rt->IsIndexed())
return TheSignedIntTypeDeclaration;
else
return lt;
case TK_LEFT_SHIFT:
case TK_RIGHT_SHIFT:
case TK_BINARY_AND:
case TK_BINARY_OR:
case TK_BINARY_XOR:
return lt;
}
}
else if (exp->mType == EX_RELATIONAL || exp->mType == EX_LOGICAL_AND || exp->mType == EX_LOGICAL_OR || exp->mType == EX_LOGICAL_NOT)
return TheBoolTypeDeclaration;
else if (exp->mType == EX_LIST)
return exp->mRight->mDecType;
return exp->mLeft->mDecType;
}
Expression* Parser::ParseBinaryFoldExpression(Expression * exp)
{
if (ConsumeTokenIf(exp->mToken))
exp->mRight = ParsePrefixExpression(false);
if (exp->mLeft)
{
Expression* pexp = FindPackExpression(exp->mLeft);
if (pexp)
{
Declaration* dpack = pexp->mDecValue->mParams;
Expression* nexp = exp->mRight;
if (dpack)
{
dpack = dpack->Last();
if (!nexp)
{
Expression* vexp = new Expression(exp->mLocation, EX_VARIABLE);
vexp->mDecType = dpack->mBase;
vexp->mDecValue = dpack;
nexp = ExpandPackExpression(exp->mLeft, pexp, vexp);
dpack = dpack->mPrev;
}
while (dpack)
{
Expression* vexp = new Expression(exp->mLocation, EX_VARIABLE);
vexp->mDecType = dpack->mBase;
vexp->mDecValue = dpack;
Expression* oexp = new Expression(exp->mLocation, exp->mType);
oexp->mToken = exp->mToken;
oexp->mLeft = ExpandPackExpression(exp->mLeft, pexp, vexp);
oexp->mRight = nexp;
oexp->mDecType = OperatorResultType(oexp);
nexp = CheckOperatorOverload(oexp);
nexp = nexp->ConstantFold(mErrors, mDataSection);
dpack = dpack->mPrev;
}
}
return nexp;
}
}
if (exp->mRight)
{
Expression* pexp = FindPackExpression(exp->mRight);
if (pexp)
{
Declaration* dpack = pexp->mDecValue->mParams;
Expression* nexp = exp->mLeft;
if (dpack)
{
if (!nexp)
{
Expression* vexp = new Expression(exp->mLocation, EX_VARIABLE);
vexp->mDecType = dpack->mBase;
vexp->mDecValue = dpack;
nexp = ExpandPackExpression(exp->mRight, pexp, vexp);
dpack = dpack->mNext;
}
while (dpack)
{
Expression* vexp = new Expression(exp->mLocation, EX_VARIABLE);
vexp->mDecType = dpack->mBase;
vexp->mDecValue = dpack;
Expression* oexp = new Expression(exp->mLocation, exp->mType);
oexp->mToken = exp->mToken;
oexp->mLeft = nexp;
oexp->mRight = ExpandPackExpression(exp->mRight, pexp, vexp);
oexp->mDecType = OperatorResultType(oexp);
nexp = CheckOperatorOverload(oexp);
nexp = nexp->ConstantFold(mErrors, mDataSection);
dpack = dpack->mNext;
}
}
return nexp;
}
}
mErrors->Error(exp->mLocation, EERR_INVALID_FOLD_EXPRESSION, "No parameter pack");
return exp;
}
Expression* Parser::ParseAddExpression(bool lhs)
{
Expression* exp = ParseMulExpression(lhs);
@ -6953,6 +7146,10 @@ Expression* Parser::ParseAddExpression(bool lhs)
nexp->mToken = mScanner->mToken;
nexp->mLeft = exp;
mScanner->NextToken();
if (ConsumeTokenIf(TK_ELLIPSIS))
return ParseBinaryFoldExpression(nexp);
nexp->mRight = ParseMulExpression(false);
if (nexp->mLeft->mDecType->mType == DT_TYPE_POINTER && nexp->mRight->mDecType->IsIntegerType())
nexp->mDecType = nexp->mLeft->mDecType;
@ -7001,6 +7198,10 @@ Expression* Parser::ParseShiftExpression(bool lhs)
nexp->mToken = mScanner->mToken;
nexp->mLeft = exp;
mScanner->NextToken();
if (ConsumeTokenIf(TK_ELLIPSIS))
return ParseBinaryFoldExpression(nexp);
nexp->mRight = ParseAddExpression(false);
nexp->mDecType = exp->mDecType;
@ -7022,6 +7223,10 @@ Expression* Parser::ParseRelationalExpression(bool lhs)
nexp->mToken = mScanner->mToken;
nexp->mLeft = exp;
mScanner->NextToken();
if (ConsumeTokenIf(TK_ELLIPSIS))
return ParseBinaryFoldExpression(nexp);
nexp->mRight = ParseShiftExpression(false);
nexp->mDecType = TheBoolTypeDeclaration;
@ -7064,6 +7269,10 @@ Expression* Parser::ParseBinaryXorExpression(bool lhs)
nexp->mToken = mScanner->mToken;
nexp->mLeft = exp;
mScanner->NextToken();
if (ConsumeTokenIf(TK_ELLIPSIS))
return ParseBinaryFoldExpression(nexp);
nexp->mRight = ParseBinaryAndExpression(false);
nexp->mDecType = exp->mDecType;
exp = nexp->ConstantFold(mErrors, mDataSection);
@ -7084,6 +7293,10 @@ Expression* Parser::ParseBinaryOrExpression(bool lhs)
nexp->mToken = mScanner->mToken;
nexp->mLeft = exp;
mScanner->NextToken();
if (ConsumeTokenIf(TK_ELLIPSIS))
return ParseBinaryFoldExpression(nexp);
nexp->mRight = ParseBinaryXorExpression(false);
nexp->mDecType = exp->mDecType;
exp = nexp->ConstantFold(mErrors, mDataSection);
@ -7104,6 +7317,10 @@ Expression* Parser::ParseLogicAndExpression(bool lhs)
nexp->mToken = mScanner->mToken;
nexp->mLeft = CoerceExpression(exp, TheBoolTypeDeclaration);
mScanner->NextToken();
if (ConsumeTokenIf(TK_ELLIPSIS))
return ParseBinaryFoldExpression(nexp);
nexp->mRight = CoerceExpression(ParseBinaryOrExpression(false), TheBoolTypeDeclaration);
nexp->mDecType = TheBoolTypeDeclaration;
exp = nexp->ConstantFold(mErrors, mDataSection);
@ -7122,6 +7339,10 @@ Expression* Parser::ParseLogicOrExpression(bool lhs)
nexp->mToken = mScanner->mToken;
nexp->mLeft = CoerceExpression(exp, TheBoolTypeDeclaration);
mScanner->NextToken();
if (ConsumeTokenIf(TK_ELLIPSIS))
return ParseBinaryFoldExpression(nexp);
nexp->mRight = CoerceExpression(ParseLogicAndExpression(false), TheBoolTypeDeclaration);
nexp->mDecType = TheBoolTypeDeclaration;
exp = nexp->ConstantFold(mErrors, mDataSection);
@ -7748,6 +7969,55 @@ Expression* Parser::ExpandArgumentPack(Expression* exp, Declaration* dec)
Expression* Parser::ParseListExpression(bool lhs)
{
if (ConsumeTokenIf(TK_ELLIPSIS))
{
ExpressionType et = EX_ERROR;
switch (mScanner->mToken)
{
case TK_MUL:
case TK_DIV:
case TK_MOD:
case TK_ADD:
case TK_SUB:
case TK_LEFT_SHIFT:
case TK_RIGHT_SHIFT:
case TK_BINARY_AND:
case TK_BINARY_OR:
case TK_BINARY_XOR:
et = EX_BINARY;
break;
case TK_EQUAL:
case TK_NOT_EQUAL:
case TK_GREATER_THAN:
case TK_GREATER_EQUAL:
case TK_LESS_THAN:
case TK_LESS_EQUAL:
et = EX_RELATIONAL;
break;
case TK_LOGICAL_AND:
et = EX_LOGICAL_AND;
break;
case TK_LOGICAL_OR:
et = EX_LOGICAL_OR;
break;
case TK_LOGICAL_NOT:
et = EX_LOGICAL_NOT;
break;
case TK_COMMA:
et = EX_LIST;
break;
}
if (et == EX_ERROR)
mErrors->Error(mScanner->mLocation, EERR_INVALID_FOLD_EXPRESSION, "Invalid operator in fold expression");
Expression* nexp = new Expression(mScanner->mLocation, et);
nexp->mToken = mScanner->mToken;
mScanner->NextToken();
nexp->mRight = ParseExpression(false);
return ParseBinaryFoldExpression(nexp);
}
Expression* exp = ParseExpression(lhs);
if (exp->mType == EX_PACK)
{
@ -7765,6 +8035,10 @@ Expression* Parser::ParseListExpression(bool lhs)
Expression* nexp = new Expression(mScanner->mLocation, EX_LIST);
nexp->mToken = TK_COMMA;
nexp->mLeft = exp;
if (ConsumeTokenIf(TK_ELLIPSIS))
return ParseBinaryFoldExpression(nexp);
nexp->mRight = ParseListExpression(false);
exp = nexp;
}
@ -9213,7 +9487,7 @@ Expression* Parser::ParseAssemblerBaseOperand(Declaration* pcasm, int pcoffset)
case TK_CHARACTER:
dec = new Declaration(mScanner->mLocation, DT_CONST_INTEGER);
dec->mInteger = mCharMap[(unsigned char)mScanner->mTokenInteger];
dec->mBase = TheUnsignedIntTypeDeclaration;
dec->mBase = TheUnsignedCharTypeDeclaration;
exp = new Expression(mScanner->mLocation, EX_CONSTANT);
exp->mDecValue = dec;
exp->mDecType = dec->mBase;

View File

@ -114,6 +114,12 @@ protected:
Expression* ParseNewOperator(void);
Expression* ParseLambdaExpression(void);
Declaration* OperatorResultType(Expression* exp);
Expression* FindPackExpression(Expression* exp);
Expression* ExpandPackExpression(Expression* exp, Expression* pack, Expression* item);
Expression* ParseBinaryFoldExpression(Expression * exp);
Expression* ParseSimpleExpression(bool lhs);
Expression* ParsePrefixExpression(bool lhs);