Protect the non-null taking functions like strcmp.
This commit is contained in:
parent
b23e62fe11
commit
10602d688f
@ -57,6 +57,57 @@ extern "C" {
|
||||
using namespace clang;
|
||||
|
||||
namespace cling {
|
||||
typedef std::map<llvm::StringRef, std::bitset<32> > nonnull_map_t;
|
||||
|
||||
// NonNullDeclFinder finds the function decls with nonnull attribute args.
|
||||
class NonNullDeclFinder
|
||||
: public RecursiveASTVisitor<NonNullDeclFinder> {
|
||||
private:
|
||||
Sema* m_Sema;
|
||||
llvm::SmallVector<llvm::StringRef, 8> NonNullDeclNames;
|
||||
nonnull_map_t NonNullArgIndexs;
|
||||
|
||||
public:
|
||||
NonNullDeclFinder(Sema* S) : m_Sema(S) {}
|
||||
const llvm::SmallVector<llvm::StringRef, 8>& getDeclNames() {
|
||||
return NonNullDeclNames;
|
||||
}
|
||||
|
||||
const nonnull_map_t& getArgIndexs() {
|
||||
return NonNullArgIndexs;
|
||||
}
|
||||
|
||||
// Deal with all the call expr in the transaction.
|
||||
bool VisitCallExpr(CallExpr* TheCall) {
|
||||
if (FunctionDecl* FDecl = TheCall->getDirectCallee()) {
|
||||
std::bitset<32> ArgIndexs;
|
||||
for (specific_attr_iterator<NonNullAttr>
|
||||
I = FDecl->specific_attr_begin<NonNullAttr>(),
|
||||
E = FDecl->specific_attr_end<NonNullAttr>(); I != E; ++I) {
|
||||
NonNullAttr *NonNull = *I;
|
||||
|
||||
// Store all the null attr argument's index into "ArgIndexs".
|
||||
for (NonNullAttr::args_iterator i = NonNull->args_begin(),
|
||||
e = NonNull->args_end(); i != e; ++i)
|
||||
ArgIndexs.set(*i);
|
||||
}
|
||||
|
||||
if (ArgIndexs.any()) {
|
||||
|
||||
// Get the function decl's name.
|
||||
llvm::StringRef FName = FDecl->getName();
|
||||
|
||||
// Store the function decl's name into the vector.
|
||||
NonNullDeclNames.push_back(FName);
|
||||
|
||||
// Store the function decl's name with its null attr args' indexs
|
||||
// into the map.
|
||||
NonNullArgIndexs.insert(std::make_pair(FName, ArgIndexs));
|
||||
}
|
||||
}
|
||||
return true; // returning false will abort the in-depth traversal.
|
||||
}
|
||||
};
|
||||
|
||||
NullDerefProtectionTransformer::NullDerefProtectionTransformer(Sema *S)
|
||||
: TransactionTransformer(S), FailBB(0), Builder(0), Inst(0) {}
|
||||
@ -65,7 +116,6 @@ namespace cling {
|
||||
{}
|
||||
|
||||
void NullDerefProtectionTransformer::Transform() {
|
||||
using namespace clang;
|
||||
FunctionDecl* FD = getTransaction()->getWrapperFD();
|
||||
if (!FD)
|
||||
return;
|
||||
@ -105,8 +155,39 @@ namespace cling {
|
||||
|
||||
// Find the function in the module.
|
||||
llvm::Function* F = getTransaction()->getModule()->getFunction(mangledName);
|
||||
if (F)
|
||||
runOnFunction(*F);
|
||||
if (!F) return;
|
||||
|
||||
llvm::IRBuilder<> TheBuilder(F->getContext());
|
||||
Builder = &TheBuilder;
|
||||
runOnFunction(*F);
|
||||
|
||||
NonNullDeclFinder Finder(m_Sema);
|
||||
|
||||
// Find all the function decls with null attribute arguments.
|
||||
for (size_t Idx = 0; Idx < getTransaction()->size(); ++Idx) {
|
||||
Transaction::DelayCallInfo I = (*getTransaction())[Idx];
|
||||
for (DeclGroupRef::const_iterator J = I.m_DGR.begin(),
|
||||
JE = I.m_DGR.end(); J != JE; ++J)
|
||||
if ((*J)->hasBody())
|
||||
Finder.TraverseStmt((*J)->getBody());
|
||||
}
|
||||
|
||||
const llvm::SmallVector<llvm::StringRef, 8>&
|
||||
FDeclNames = Finder.getDeclNames();
|
||||
if (FDeclNames.empty()) return;
|
||||
|
||||
llvm::Module* M = F->getParent();
|
||||
|
||||
for (llvm::SmallVector<llvm::StringRef, 8>::const_iterator
|
||||
i = FDeclNames.begin(), e = FDeclNames.end(); i != e; ++i) {
|
||||
const nonnull_map_t& ArgIndexs = Finder.getArgIndexs();
|
||||
nonnull_map_t::const_iterator it = ArgIndexs.find(*i);
|
||||
|
||||
if (it != ArgIndexs.end()) {
|
||||
const std::bitset<32>& ArgNums = it->second;
|
||||
handleNonNullArgCall(*M, *i, ArgNums);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
llvm::BasicBlock*
|
||||
@ -115,6 +196,7 @@ namespace cling {
|
||||
llvm::Module* Md = Fn->getParent();
|
||||
llvm::LLVMContext& ctx = Fn->getContext();
|
||||
|
||||
llvm::BasicBlock::iterator PreInsertInst = Builder->GetInsertPoint();
|
||||
FailBB = llvm::BasicBlock::Create(ctx, "FailBlock", Fn);
|
||||
Builder->SetInsertPoint(FailBB);
|
||||
|
||||
@ -157,6 +239,7 @@ namespace cling {
|
||||
llvm::BranchInst::Create(HandleBB, BB, Cmp, FailBB);
|
||||
|
||||
llvm::ReturnInst::Create(Fn->getContext(), HandleBB);
|
||||
Builder->SetInsertPoint(PreInsertInst);
|
||||
return FailBB;
|
||||
}
|
||||
|
||||
@ -182,8 +265,6 @@ namespace cling {
|
||||
}
|
||||
|
||||
bool NullDerefProtectionTransformer::runOnFunction(llvm::Function &F) {
|
||||
llvm::IRBuilder<> TheBuilder(F.getContext());
|
||||
Builder = &TheBuilder;
|
||||
|
||||
std::vector<llvm::Instruction*> WorkList;
|
||||
for (llvm::inst_iterator i = inst_begin(F), e = inst_end(F); i != e; ++i) {
|
||||
@ -195,6 +276,7 @@ namespace cling {
|
||||
for (std::vector<llvm::Instruction*>::iterator i = WorkList.begin(),
|
||||
e = WorkList.end(); i != e; ++i) {
|
||||
Inst = *i;
|
||||
Builder->SetInsertPoint(Inst);
|
||||
llvm::LoadInst* I = llvm::cast<llvm::LoadInst>(*i);
|
||||
|
||||
// Find all the instructions that uses the instruction I.
|
||||
@ -243,4 +325,62 @@ namespace cling {
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void NullDerefProtectionTransformer::instrumentCallInst(llvm::Instruction*
|
||||
TheCall, const std::bitset<32>& ArgIndexs) {
|
||||
llvm::Type* Int8PtrTy = llvm::Type::getInt8PtrTy(TheCall->getContext());
|
||||
llvm::CallSite CS = TheCall;
|
||||
|
||||
for (int index = 0; index < 32; ++index) {
|
||||
if (!ArgIndexs.test(index)) continue;
|
||||
llvm::Value* Arg = CS.getArgument(index);
|
||||
if (!Arg) continue;
|
||||
llvm::Type* ArgTy = Arg->getType();
|
||||
if (ArgTy != Int8PtrTy)
|
||||
continue;
|
||||
|
||||
llvm::BasicBlock* OldBB = TheCall->getParent();
|
||||
llvm::ICmpInst* Cmp
|
||||
= new llvm::ICmpInst(TheCall, llvm::CmpInst::ICMP_EQ, Arg,
|
||||
llvm::Constant::getNullValue(ArgTy), "");
|
||||
|
||||
llvm::Instruction* Inst = Builder->GetInsertPoint();
|
||||
llvm::BasicBlock* NewBB = OldBB->splitBasicBlock(Inst);
|
||||
|
||||
OldBB->getTerminator()->eraseFromParent();
|
||||
llvm::BranchInst::Create(getTrapBB(NewBB), NewBB, Cmp, OldBB);
|
||||
}
|
||||
}
|
||||
|
||||
void NullDerefProtectionTransformer::handleNonNullArgCall(llvm::Module& M,
|
||||
const llvm::StringRef& name, const std::bitset<32>& ArgIndexs) {
|
||||
|
||||
// Get the function by the name.
|
||||
llvm::Function* func = M.getFunction(name);
|
||||
if (!func)
|
||||
return;
|
||||
|
||||
// Find all the instructions that calls the function.
|
||||
std::vector<llvm::Instruction*> WorkList;
|
||||
for (llvm::Value::use_iterator I = func->use_begin(), E = func->use_end();
|
||||
I != E; ++I) {
|
||||
llvm::CallSite CS(*I);
|
||||
if (!CS || CS.getCalledValue() != func)
|
||||
continue;
|
||||
WorkList.push_back(CS.getInstruction());
|
||||
}
|
||||
|
||||
// There is no call instructions that call the function and return.
|
||||
if (WorkList.empty())
|
||||
return;
|
||||
|
||||
// Instrument all the call instructions that call the function.
|
||||
for (std::vector<llvm::Instruction*>::iterator I = WorkList.begin(),
|
||||
E = WorkList.end(); I != E; ++I) {
|
||||
Inst = *I;
|
||||
Builder->SetInsertPoint(Inst);
|
||||
instrumentCallInst(Inst, ArgIndexs);
|
||||
}
|
||||
return;
|
||||
}
|
||||
} // end namespace cling
|
||||
|
@ -38,6 +38,11 @@ namespace cling {
|
||||
llvm::BasicBlock* getTrapBB(llvm::BasicBlock* BB);
|
||||
void instrumentInst(llvm::Instruction* Inst, llvm::Value* Arg);
|
||||
bool runOnFunction(llvm::Function& F);
|
||||
void instrumentCallInst(llvm::Instruction* TheCall,
|
||||
const std::bitset<32>& ArgIndexs);
|
||||
void handleNonNullArgCall(llvm::Module& M,
|
||||
const llvm::StringRef& FName,
|
||||
const std::bitset<32>& ArgIndexs);
|
||||
|
||||
public:
|
||||
NullDerefProtectionTransformer(clang::Sema* S);
|
||||
|
Loading…
x
Reference in New Issue
Block a user