Protect the non-null taking functions like strcmp.

This commit is contained in:
Baozeng Ding 2013-08-14 20:29:49 +02:00 committed by sftnight
parent b23e62fe11
commit 10602d688f
2 changed files with 150 additions and 5 deletions

View File

@ -57,6 +57,57 @@ extern "C" {
using namespace clang;
namespace cling {
typedef std::map<llvm::StringRef, std::bitset<32> > nonnull_map_t;
// NonNullDeclFinder finds the function decls with nonnull attribute args.
class NonNullDeclFinder
: public RecursiveASTVisitor<NonNullDeclFinder> {
private:
Sema* m_Sema;
llvm::SmallVector<llvm::StringRef, 8> NonNullDeclNames;
nonnull_map_t NonNullArgIndexs;
public:
NonNullDeclFinder(Sema* S) : m_Sema(S) {}
const llvm::SmallVector<llvm::StringRef, 8>& getDeclNames() {
return NonNullDeclNames;
}
const nonnull_map_t& getArgIndexs() {
return NonNullArgIndexs;
}
// Deal with all the call expr in the transaction.
bool VisitCallExpr(CallExpr* TheCall) {
if (FunctionDecl* FDecl = TheCall->getDirectCallee()) {
std::bitset<32> ArgIndexs;
for (specific_attr_iterator<NonNullAttr>
I = FDecl->specific_attr_begin<NonNullAttr>(),
E = FDecl->specific_attr_end<NonNullAttr>(); I != E; ++I) {
NonNullAttr *NonNull = *I;
// Store all the null attr argument's index into "ArgIndexs".
for (NonNullAttr::args_iterator i = NonNull->args_begin(),
e = NonNull->args_end(); i != e; ++i)
ArgIndexs.set(*i);
}
if (ArgIndexs.any()) {
// Get the function decl's name.
llvm::StringRef FName = FDecl->getName();
// Store the function decl's name into the vector.
NonNullDeclNames.push_back(FName);
// Store the function decl's name with its null attr args' indexs
// into the map.
NonNullArgIndexs.insert(std::make_pair(FName, ArgIndexs));
}
}
return true; // returning false will abort the in-depth traversal.
}
};
NullDerefProtectionTransformer::NullDerefProtectionTransformer(Sema *S)
: TransactionTransformer(S), FailBB(0), Builder(0), Inst(0) {}
@ -65,7 +116,6 @@ namespace cling {
{}
void NullDerefProtectionTransformer::Transform() {
using namespace clang;
FunctionDecl* FD = getTransaction()->getWrapperFD();
if (!FD)
return;
@ -105,8 +155,39 @@ namespace cling {
// Find the function in the module.
llvm::Function* F = getTransaction()->getModule()->getFunction(mangledName);
if (F)
runOnFunction(*F);
if (!F) return;
llvm::IRBuilder<> TheBuilder(F->getContext());
Builder = &TheBuilder;
runOnFunction(*F);
NonNullDeclFinder Finder(m_Sema);
// Find all the function decls with null attribute arguments.
for (size_t Idx = 0; Idx < getTransaction()->size(); ++Idx) {
Transaction::DelayCallInfo I = (*getTransaction())[Idx];
for (DeclGroupRef::const_iterator J = I.m_DGR.begin(),
JE = I.m_DGR.end(); J != JE; ++J)
if ((*J)->hasBody())
Finder.TraverseStmt((*J)->getBody());
}
const llvm::SmallVector<llvm::StringRef, 8>&
FDeclNames = Finder.getDeclNames();
if (FDeclNames.empty()) return;
llvm::Module* M = F->getParent();
for (llvm::SmallVector<llvm::StringRef, 8>::const_iterator
i = FDeclNames.begin(), e = FDeclNames.end(); i != e; ++i) {
const nonnull_map_t& ArgIndexs = Finder.getArgIndexs();
nonnull_map_t::const_iterator it = ArgIndexs.find(*i);
if (it != ArgIndexs.end()) {
const std::bitset<32>& ArgNums = it->second;
handleNonNullArgCall(*M, *i, ArgNums);
}
}
}
llvm::BasicBlock*
@ -115,6 +196,7 @@ namespace cling {
llvm::Module* Md = Fn->getParent();
llvm::LLVMContext& ctx = Fn->getContext();
llvm::BasicBlock::iterator PreInsertInst = Builder->GetInsertPoint();
FailBB = llvm::BasicBlock::Create(ctx, "FailBlock", Fn);
Builder->SetInsertPoint(FailBB);
@ -157,6 +239,7 @@ namespace cling {
llvm::BranchInst::Create(HandleBB, BB, Cmp, FailBB);
llvm::ReturnInst::Create(Fn->getContext(), HandleBB);
Builder->SetInsertPoint(PreInsertInst);
return FailBB;
}
@ -182,8 +265,6 @@ namespace cling {
}
bool NullDerefProtectionTransformer::runOnFunction(llvm::Function &F) {
llvm::IRBuilder<> TheBuilder(F.getContext());
Builder = &TheBuilder;
std::vector<llvm::Instruction*> WorkList;
for (llvm::inst_iterator i = inst_begin(F), e = inst_end(F); i != e; ++i) {
@ -195,6 +276,7 @@ namespace cling {
for (std::vector<llvm::Instruction*>::iterator i = WorkList.begin(),
e = WorkList.end(); i != e; ++i) {
Inst = *i;
Builder->SetInsertPoint(Inst);
llvm::LoadInst* I = llvm::cast<llvm::LoadInst>(*i);
// Find all the instructions that uses the instruction I.
@ -243,4 +325,62 @@ namespace cling {
}
return true;
}
void NullDerefProtectionTransformer::instrumentCallInst(llvm::Instruction*
TheCall, const std::bitset<32>& ArgIndexs) {
llvm::Type* Int8PtrTy = llvm::Type::getInt8PtrTy(TheCall->getContext());
llvm::CallSite CS = TheCall;
for (int index = 0; index < 32; ++index) {
if (!ArgIndexs.test(index)) continue;
llvm::Value* Arg = CS.getArgument(index);
if (!Arg) continue;
llvm::Type* ArgTy = Arg->getType();
if (ArgTy != Int8PtrTy)
continue;
llvm::BasicBlock* OldBB = TheCall->getParent();
llvm::ICmpInst* Cmp
= new llvm::ICmpInst(TheCall, llvm::CmpInst::ICMP_EQ, Arg,
llvm::Constant::getNullValue(ArgTy), "");
llvm::Instruction* Inst = Builder->GetInsertPoint();
llvm::BasicBlock* NewBB = OldBB->splitBasicBlock(Inst);
OldBB->getTerminator()->eraseFromParent();
llvm::BranchInst::Create(getTrapBB(NewBB), NewBB, Cmp, OldBB);
}
}
void NullDerefProtectionTransformer::handleNonNullArgCall(llvm::Module& M,
const llvm::StringRef& name, const std::bitset<32>& ArgIndexs) {
// Get the function by the name.
llvm::Function* func = M.getFunction(name);
if (!func)
return;
// Find all the instructions that calls the function.
std::vector<llvm::Instruction*> WorkList;
for (llvm::Value::use_iterator I = func->use_begin(), E = func->use_end();
I != E; ++I) {
llvm::CallSite CS(*I);
if (!CS || CS.getCalledValue() != func)
continue;
WorkList.push_back(CS.getInstruction());
}
// There is no call instructions that call the function and return.
if (WorkList.empty())
return;
// Instrument all the call instructions that call the function.
for (std::vector<llvm::Instruction*>::iterator I = WorkList.begin(),
E = WorkList.end(); I != E; ++I) {
Inst = *I;
Builder->SetInsertPoint(Inst);
instrumentCallInst(Inst, ArgIndexs);
}
return;
}
} // end namespace cling

View File

@ -38,6 +38,11 @@ namespace cling {
llvm::BasicBlock* getTrapBB(llvm::BasicBlock* BB);
void instrumentInst(llvm::Instruction* Inst, llvm::Value* Arg);
bool runOnFunction(llvm::Function& F);
void instrumentCallInst(llvm::Instruction* TheCall,
const std::bitset<32>& ArgIndexs);
void handleNonNullArgCall(llvm::Module& M,
const llvm::StringRef& FName,
const std::bitset<32>& ArgIndexs);
public:
NullDerefProtectionTransformer(clang::Sema* S);