Automatic function template expansion

This commit is contained in:
drmortalwombat 2023-08-12 15:25:30 +02:00
parent 5b7334bb17
commit cb5451b9b9
4 changed files with 197 additions and 41 deletions

View File

@ -10,6 +10,19 @@ DeclarationScope::DeclarationScope(DeclarationScope* parent, ScopeLevel level, c
mHash = nullptr;
}
void DeclarationScope::Clear(void)
{
mHashFill = 0;
if (mHash)
{
for (int i = 0; i < mHashSize; i++)
{
mHash[i].mDec = nullptr;
mHash[i].mIdent = nullptr;
}
}
}
DeclarationScope::~DeclarationScope(void)
{
delete[] mHash;
@ -797,7 +810,7 @@ Declaration::Declaration(const Location& loc, DecType type)
mConst(nullptr), mMutable(nullptr),
mDefaultConstructor(nullptr), mDestructor(nullptr), mCopyConstructor(nullptr), mCopyAssignment(nullptr),
mVectorConstructor(nullptr), mVectorDestructor(nullptr), mVectorCopyConstructor(nullptr), mVectorCopyAssignment(nullptr),
mVTable(nullptr),
mVTable(nullptr), mTemplate(nullptr),
mVarIndex(-1), mLinkerObject(nullptr), mCallers(nullptr), mCalled(nullptr), mAlignment(1),
mInteger(0), mNumber(0), mMinValue(-0x80000000LL), mMaxValue(0x7fffffffLL), mFastCallBase(0), mFastCallSize(0), mStride(0), mStripe(1),
mCompilerOptions(0), mUseCount(0), mTokens(nullptr), mParser(nullptr)
@ -920,6 +933,11 @@ const Ident* Declaration::MangleIdent(void)
else
mMangleIdent = mQualIdent;
if (mTemplate)
{
}
if (mFlags & DTF_CONST)
mMangleIdent = mMangleIdent->PreMangle("const ");
}
@ -927,44 +945,146 @@ const Ident* Declaration::MangleIdent(void)
return mMangleIdent;
}
Declaration* Declaration::TemplateExpand(Declaration* tdec)
bool Declaration::ResolveTemplate(Expression* pexp, Declaration* tdec)
{
if (mType == DT_ARGUMENT)
Declaration* pdec = tdec->mBase->mParams;
while (pexp)
{
Declaration* edec = this->Clone();
edec->mBase = mBase->TemplateExpand(tdec);
if (mNext)
edec->mNext = mNext->TemplateExpand(tdec);
return edec;
Expression* ex = pexp;
if (pexp->mType == EX_LIST)
{
ex = pexp->mLeft;
pexp = pexp->mRight;
}
else
pexp = nullptr;
if (pdec)
{
if (!ResolveTemplate(ex->mDecType, pdec->mBase))
return false;
pdec = pdec->mNext;
}
else
return false;
}
else if (mType == DT_CONST_FUNCTION)
Declaration* ppdec = nullptr;
Declaration* ptdec = tdec->mTemplate->mParams;
while (ptdec)
{
Declaration* edec = this->Clone();
edec->mBase = mBase->TemplateExpand(tdec);
return edec;
Declaration* pdec = mScope->Lookup(ptdec->mIdent);
if (!pdec)
return false;
Declaration * epdec = ptdec->Clone();
epdec->mBase = pdec;
epdec->mFlags |= DTF_DEFINED;
if (ppdec)
ppdec->mNext = epdec;
else
mParams = epdec;
ppdec = epdec;
ptdec = ptdec->mNext;
}
else if (mType == DT_TYPE_FUNCTION)
mScope->Clear();
return true;
}
bool Declaration::CanResolveTemplate(Expression* pexp, Declaration* tdec)
{
Declaration* pdec = tdec->mBase->mParams;
while (pexp)
{
Declaration* edec = this->Clone();
edec->mBase = mBase->TemplateExpand(tdec);
if (edec->mParams)
edec->mParams = mParams->TemplateExpand(tdec);
return edec;
Expression* ex = pexp;
if (pexp->mType == EX_LIST)
{
ex = pexp->mLeft;
pexp = pexp->mRight;
}
else
pexp = nullptr;
if (pdec)
{
if (!ResolveTemplate(ex->mDecType, pdec->mBase))
return false;
pdec = pdec->mNext;
}
else
return false;
}
else if (mType == DT_TYPE_TEMPLATE || mType == DT_CONST_TEMPLATE)
return true;
}
bool Declaration::ResolveTemplate(Declaration* fdec, Declaration* tdec)
{
if (tdec->IsSame(fdec))
return true;
else if (tdec->mType == DT_TYPE_FUNCTION)
{
return tdec->mScope->Lookup(mIdent);
if (fdec->mType == DT_TYPE_FUNCTION)
{
if (fdec->mBase)
{
if (!tdec->mBase || !ResolveTemplate(fdec->mBase, tdec->mBase))
return false;
}
else if (tdec->mBase)
return false;
Declaration* fpdec = fdec->mParams;
Declaration* tpdec = tdec->mParams;
while (fpdec && tpdec)
{
if (!ResolveTemplate(fpdec->mBase, tpdec->mBase))
return false;
fpdec = fpdec->mNext;
tpdec = tpdec->mNext;
}
if (fpdec || tpdec)
return false;
}
else
return false;
return true;
}
else if (mType == DT_TYPE_POINTER || mType == DT_TYPE_ARRAY || mType == DT_TYPE_REFERENCE)
else if (tdec->mType == DT_TYPE_REFERENCE)
{
Declaration* edec = this->Clone();
edec->mBase = mBase->TemplateExpand(tdec);
return edec;
if (fdec->mType == DT_TYPE_REFERENCE)
return ResolveTemplate(fdec->mBase, tdec->mBase);
else
return ResolveTemplate(fdec, tdec->mBase);
}
else if (tdec->mType == DT_TYPE_POINTER)
{
if (fdec->mType == DT_TYPE_POINTER)
return ResolveTemplate(fdec->mBase, tdec->mBase);
else
return false;
}
else if (tdec->mType == DT_TYPE_TEMPLATE)
{
Declaration* pdec = mScope->Insert(tdec->mIdent, fdec);
if (pdec && !pdec->IsSame(fdec))
return false;
return true;
}
else
return this;
return tdec->CanAssign(fdec);
}
Declaration* Declaration::Clone(void)
{
Declaration* ndec = new Declaration(mLocation, mType);

View File

@ -136,6 +136,7 @@ public:
Declaration* Lookup(const Ident* ident, ScopeLevel limit = SLEVEL_GLOBAL);
void End(const Location & loc);
void Clear(void);
void UseScope(DeclarationScope* scope);
@ -250,7 +251,7 @@ public:
Declaration * mBase, * mParams, * mNext, * mPrev, * mConst, * mMutable;
Declaration * mDefaultConstructor, * mDestructor, * mCopyConstructor, * mCopyAssignment;
Declaration * mVectorConstructor, * mVectorDestructor, * mVectorCopyConstructor, * mVectorCopyAssignment;
Declaration * mVTable, * mClass;
Declaration * mVTable, * mClass, * mTemplate;
Expression* mValue;
DeclarationScope* mScope;
@ -296,7 +297,9 @@ public:
Declaration* BuildConstPointer(const Location& loc);
Declaration* BuildConstReference(const Location& loc);
Declaration* TemplateExpand(Declaration* tdec);
bool CanResolveTemplate(Expression* pexp, Declaration* tdec);
bool ResolveTemplate(Declaration* fdec, Declaration * tdec);
bool ResolveTemplate(Expression* pexp, Declaration* tdec);
const Ident* MangleIdent(void);

View File

@ -132,7 +132,7 @@ Declaration* Parser::ParseStructDeclaration(uint64 flags, DecType dt)
structName = mScanner->mTokenIdent;
mScanner->NextToken();
Declaration* edec = mScope->Lookup(structName);
if (edec && edec->mType == DT_TEMPLATE)
if (edec && edec->mTemplate)
{
mTemplateScope->Insert(structName, dec);
@ -596,8 +596,8 @@ Declaration* Parser::ParseBaseTypeDeclaration(uint64 flags, bool qualified)
mScanner->NextToken();
if (dec && dec->mType == DT_TEMPLATE)
dec = ParseTemplateExpansion(dec, nullptr);
if (dec && dec->mTemplate)
dec = ParseTemplateExpansion(dec->mTemplate, nullptr);
while (qualified && dec && dec->mType == DT_TYPE_STRUCT && ConsumeTokenIf(TK_COLCOLON))
{
@ -846,7 +846,7 @@ Declaration* Parser::ParsePostfixDeclaration(void)
if (mScanner->mToken == TK_LESS_THAN && mTemplateScope)
{
Declaration* tdec = mScope->Lookup(dec->mIdent);
if (tdec && tdec->mType == DT_TEMPLATE)
if (tdec && tdec->mTemplate)
{
// for now just skip over template stuff
while (!ConsumeTokenIf(TK_GREATER_THAN))
@ -3730,7 +3730,11 @@ Declaration* Parser::ParseDeclaration(Declaration * pdec, bool variable, bool ex
if (ndec->mIdent == ndec->mQualIdent)
{
Declaration* ldec = mScope->Insert(ndec->mIdent, pdec ? pdec : ndec);
if (ldec && ldec != pdec)
if (ldec && ldec->mTemplate && mTemplateScope)
{
ndec->mQualIdent = ndec->mQualIdent->Mangle(mTemplateScope->mName->mString);
}
else if (ldec && ldec != pdec)
mErrors->Error(ndec->mLocation, EERR_DUPLICATE_DEFINITION, "Duplicate definition");
}
else if (!pdec)
@ -4463,8 +4467,8 @@ Expression* Parser::ParseSimpleExpression(bool lhs)
dec = ParseQualIdent();
if (dec)
{
if (dec->mType == DT_TEMPLATE)
dec = ParseTemplateExpansion(dec, nullptr);
if (dec->mTemplate && mScanner->mToken == TK_LESS_THAN)
dec = ParseTemplateExpansion(dec->mTemplate, nullptr);
if (dec->mType == DT_CONST_INTEGER || dec->mType == DT_CONST_FLOAT || dec->mType == DT_CONST_FUNCTION || dec->mType == DT_CONST_ASSEMBLER || dec->mType == DT_LABEL || dec->mType == DT_LABEL_REF)
{
@ -4758,7 +4762,17 @@ static const int NOOVERLOAD = 0x7fffffff;
int Parser::OverloadDistance(Declaration* fdec, Expression* pexp)
{
Declaration* pdec = fdec->mParams;
if (fdec->mTemplate)
{
Declaration* tdec = new Declaration(mScanner->mLocation, DT_TEMPLATE);
tdec->mScope = new DeclarationScope(nullptr, SLEVEL_TEMPLATE);
if (tdec->CanResolveTemplate(pexp, fdec))
return 16;
else
return NOOVERLOAD;
}
Declaration* pdec = fdec->mBase->mParams;
int dist = 0;
@ -4884,7 +4898,7 @@ int Parser::OverloadDistance(Declaration* fdec, Expression* pexp)
pdec = pdec->mNext;
}
else if (fdec->mFlags & DTF_VARIADIC)
else if (fdec->mBase->mFlags & DTF_VARIADIC)
{
dist += 1024;
break;
@ -5052,6 +5066,20 @@ Expression* Parser::CoerceExpression(Expression* exp, Declaration* type)
return exp;
}
void Parser::ExpandFunctionCallTemplate(Expression* exp)
{
if (exp->mLeft->mDecValue->mTemplate)
{
Declaration* tdec = new Declaration(mScanner->mLocation, DT_TEMPLATE);
tdec->mScope = new DeclarationScope(nullptr, SLEVEL_TEMPLATE);
if (tdec->ResolveTemplate(exp->mRight, exp->mLeft->mDecValue))
{
exp->mLeft->mDecValue = ParseTemplateExpansion(exp->mLeft->mDecValue->mTemplate, tdec);
exp->mLeft->mDecType = exp->mLeft->mDecValue->mBase;
}
}
}
void Parser::CompleteFunctionDefaultParams(Expression* exp)
{
Declaration* fdec = exp->mLeft->mDecValue;
@ -5110,7 +5138,7 @@ Expression * Parser::ResolveOverloadCall(Expression* exp, Expression* exp2)
while (fdec)
{
int d = OverloadDistance(fdec->mBase, exp->mRight);
int d = OverloadDistance(fdec, exp->mRight);
if (d < ibest)
{
dbest = fdec;
@ -5125,7 +5153,7 @@ Expression * Parser::ResolveOverloadCall(Expression* exp, Expression* exp2)
while (fdec2)
{
int d = OverloadDistance(fdec2->mBase, exp2->mRight);
int d = OverloadDistance(fdec2, exp2->mRight);
if (d < ibest)
{
dbest = fdec2;
@ -5156,6 +5184,7 @@ Expression * Parser::ResolveOverloadCall(Expression* exp, Expression* exp2)
}
}
ExpandFunctionCallTemplate(exp);
CompleteFunctionDefaultParams(exp);
}
@ -7209,9 +7238,9 @@ Declaration* Parser::ParseTemplateExpansion(Declaration* tmpld, Declaration* exp
if (!(mdec->mFlags & DTF_DEFINED))
{
Declaration* mpdec = mScope->Lookup(tmpld->mScope->Mangle(mdec->mIdent));
if (mpdec && mpdec->mType == DT_TEMPLATE)
if (mpdec && mpdec->mTemplate)
{
p->ParseTemplateExpansion(mpdec, tdec);
p->ParseTemplateExpansion(mpdec->mTemplate, tdec);
}
}
@ -7316,6 +7345,7 @@ void Parser::ParseTemplateDeclaration(void)
// Class template
Declaration* bdec = new Declaration(mScanner->mLocation, DT_TYPE_STRUCT);
tdec->mBase = bdec;
bdec->mTemplate = tdec;
mScanner->NextToken();
if (mScanner->mToken == TK_IDENT)
@ -7364,6 +7394,7 @@ void Parser::ParseTemplateDeclaration(void)
adec->mFlags |= DTF_NATIVE;
tdec->mBase = adec;
adec->mTemplate = tdec;
if (ConsumeTokenIf(TK_OPEN_BRACE))
{
@ -7389,7 +7420,8 @@ void Parser::ParseTemplateDeclaration(void)
tdec->mQualIdent = tdec->mBase->mQualIdent;
tdec->mScope->mName = tdec->mQualIdent;
mScope->Insert(tdec->mQualIdent, tdec);
mScope->Insert(tdec->mQualIdent, tdec->mBase);
mCompilationUnits->mScope->Insert(tdec->mQualIdent, tdec->mBase);
}

View File

@ -94,6 +94,7 @@ protected:
Expression* CoerceExpression(Expression* exp, Declaration* type);
bool CanCoerceExpression(Expression* exp, Declaration* type);
void CompleteFunctionDefaultParams(Expression* exp);
void ExpandFunctionCallTemplate(Expression* exp);
void ParseTemplateDeclaration(void);
Declaration* ParseTemplateExpansion(Declaration* tmpld, Declaration* expd);