1 Star 0 Fork 0

陈柏居/archLab

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
brchPredict.cpp 18.12 KB
一键复制 编辑 原始数据 按行查看 历史
陈柏居 提交于 2022-11-24 11:02 . 自欺欺人
#include <iostream>
#include <fstream>
#include <cassert>
#include <stdarg.h>
#include <cstdlib>
#include <cstring>
#include "pin.H"
using namespace std;
typedef unsigned char UINT8;
typedef unsigned short UINT16;
typedef unsigned int UINT32;
typedef unsigned long int UINT64;
typedef unsigned __int128 UINT128;
std::ofstream OutFile;
// Keep only the low `bits` bits of `val` (requires 0 <= bits < 128).
// The mask is built from a 128-bit one: a plain int `1 << bits` is undefined
// behaviour once bits reaches the width of int, which long histories can hit.
#define truncate(val, bits) ((val) & (((UINT128)1 << (bits)) - 1))
// Outcome tallies updated by predictBranch() and reported by Fini().
static UINT64 takenCorrect = 0;
static UINT64 takenIncorrect = 0;
static UINT64 notTakenCorrect = 0;
static UINT64 notTakenIncorrect = 0;
// Debug counters printed by Fini(); currently only referenced by commented-out code.
static UINT64 use0 = 0;
static UINT64 use1 = 0;
// 饱和计数器 (N < 64)
class SaturatingCnt
{
size_t m_wid;
UINT8 m_val;
const UINT8 m_init_val;
public:
SaturatingCnt(size_t width = 2) : m_init_val((1 << width) / 2)
{
m_wid = width;
m_val = m_init_val;
}
void increase() { if (m_val < (1 << m_wid) - 1) m_val++; }
void decrease() { if (m_val > 0) m_val--; }
void reset() { m_val = m_init_val; }
UINT8 getVal() { return m_val; }
bool isTaken() { return (m_val > (1 << m_wid)/2 - 1); }
};
// 移位寄存器 (N < 128)
class ShiftReg
{
size_t m_wid;
UINT128 m_val;
public:
ShiftReg(size_t width) : m_wid(width), m_val(0) {}
bool shiftIn(bool b)
{
bool ret = !!(m_val & (1 << (m_wid - 1)));
m_val <<= 1;
m_val |= b;
m_val &= (1 << m_wid) - 1;
return ret;
}
UINT128 getVal() { return m_val; }
UINT128 fold(size_t size) {
// 将移位寄存器的值折叠, 相当于另一个哈希函数
UINT128 ret = 0;
for (size_t i = 0; i < m_wid/size; i++) {
ret ^= m_val >> (size * i);
}
return truncate(ret, size);
}
};
// Hash functions
// Address/history combining functions, passed as template arguments to the
// predictors below (non-static inline, so they keep external linkage).
// Plain XOR.
inline UINT128 f_xor(UINT128 a, UINT128 b) { return a ^ b; }
// XOR of the complemented inputs. Note: (~a) ^ (~b) == a ^ b, so this is
// numerically identical to f_xor, not an independent hash.
inline UINT128 f_xor1(UINT128 a, UINT128 b) { return ~a ^ ~b; }
// XNOR: the bitwise complement of a ^ b. (The form ~(a ^ ~b) would simplify
// back to a ^ b and therefore not be an XNOR.)
inline UINT128 f_xnor(UINT128 a, UINT128 b) { return ~(a ^ b); }
// Width in bits of the partial tags kept by TAGEPredictor (stored in UINT8 cells, so at most 8).
static const size_t TAG_SIZE = 6;
// Base class of all predictors
// Interface implemented by every predictor in this tool. The defaults always
// predict "not taken" and ignore training.
class BranchPredictor
{
public:
    BranchPredictor() {}
    // Virtual so that deleting through a base pointer (e.g. `delete BP`)
    // destroys the concrete predictor.
    virtual ~BranchPredictor() {}
    // Predicted direction for the branch at the given address (true = taken).
    virtual bool predict(ADDRINT /*addr*/)
    {
        return false;
    }
    // Train with the real outcome and the direction that was predicted.
    virtual void update(bool /*takenActually*/, bool /*takenPredicted*/, ADDRINT /*addr*/)
    {
    }
};
// Predictor chosen in main() and driven by predictBranch().
BranchPredictor* BP = 0;
/* ===================================================================== */
/* BHT-based branch predictor */
/* ===================================================================== */
// Bimodal predictor: a branch history table (BHT) of 2^entry_num_log saturating
// counters indexed by the low bits of the branch address.
class BHTPredictor: public BranchPredictor
{
    size_t m_entries_log;             // log2 of the number of BHT entries
    SaturatingCnt* m_scnt;            // BHT counters, owned
    allocator<SaturatingCnt> m_alloc;

    // The table is owned through a raw pointer; an implicit copy would
    // double-free it, so copying is disabled (declared, never defined).
    BHTPredictor(const BHTPredictor&);
    BHTPredictor& operator=(const BHTPredictor&);

    // Number of BHT entries.
    size_t entries() const { return (size_t)1 << m_entries_log; }
public:
    // param: entry_num_log: log2 of the number of BHT rows
    //        scnt_width:    width of each saturating counter (2 by default)
    BHTPredictor(size_t entry_num_log, size_t scnt_width = 2)
        : m_entries_log(entry_num_log)
    {
        const size_t n = entries();
        m_scnt = m_alloc.allocate(n);
        for (size_t i = 0; i < n; i++)
            m_alloc.construct(m_scnt + i, scnt_width);
    }
    ~BHTPredictor()
    {
        const size_t n = entries();
        for (size_t i = 0; i < n; i++)
            m_alloc.destroy(m_scnt + i);
        m_alloc.deallocate(m_scnt, n);
    }
    // Predict with the counter selected by the low address bits.
    bool predict(ADDRINT addr)
    {
        return m_scnt[truncate(addr, m_entries_log)].isTaken();
    }
    // Train the selected counter with the real outcome (takenPredicted is unused).
    void update(bool takenActually, bool takenPredicted, ADDRINT addr)
    {
        SaturatingCnt* scnt = &m_scnt[truncate(addr, m_entries_log)];
        // From a weak state (1 or 2), a misprediction jumps straight to the
        // opposite strong state (1 -> 3, 2 -> 0).
        // NOTE(review): the literals 1 and 2 are the weak states only for 2-bit
        // counters; other scnt_width values get different transitions.
        if (takenActually) {
            if (scnt->getVal() == 1) scnt->increase();
            scnt->increase();
        } else {
            if (scnt->getVal() == 2) scnt->decrease();
            scnt->decrease();
        }
    }
};
/* ===================================================================== */
/* Global-history-based branch predictor */
/* ===================================================================== */
template<UINT128 (*hash)(UINT128 addr, UINT128 history)>
class GlobalHistoryPredictor: public BranchPredictor
{
    ShiftReg* m_ghr;                  // GHR, owned
    SaturatingCnt* m_scnt;            // PHT counters, owned
    size_t m_entries_log;             // log2 of the number of PHT entries
    allocator<SaturatingCnt> m_alloc;

    // Owns m_ghr and m_scnt through raw pointers: copying is disabled.
    GlobalHistoryPredictor(const GlobalHistoryPredictor&);
    GlobalHistoryPredictor& operator=(const GlobalHistoryPredictor&);

    // Number of PHT entries.
    size_t entries() const { return (size_t)1 << m_entries_log; }

    // PHT counter selected by hash(addr, GHR) truncated to m_entries_log bits.
    // NOTE(review): history bits above m_entries_log do not influence the
    // index (they only influence get_tag()); TAGEPredictor::get_index relies
    // on this exact formula, so change both together.
    SaturatingCnt& counter(ADDRINT addr)
    {
        return m_scnt[truncate(hash(addr, get_ghr()), m_entries_log)];
    }
public:
    // param: ghr_width:     width of the GHR
    //        entry_num_log: log2 of the number of PHT rows
    //        scnt_width:    width of each saturating counter (2 by default)
    GlobalHistoryPredictor(size_t ghr_width, size_t entry_num_log, size_t scnt_width = 2)
        : m_ghr(new ShiftReg(ghr_width)), m_scnt(0), m_entries_log(entry_num_log)
    {
        const size_t n = entries();
        m_scnt = m_alloc.allocate(n);
        for (size_t i = 0; i < n; i++)
            m_alloc.construct(m_scnt + i, scnt_width);
    }
    ~GlobalHistoryPredictor()
    {
        delete m_ghr;
        const size_t n = entries();
        for (size_t i = 0; i < n; i++)
            m_alloc.destroy(m_scnt + i);
        m_alloc.deallocate(m_scnt, n);
    }
    // Only for TAGE: TAG_SIZE-bit tag from the address and the folded GHR.
    UINT128 get_tag(ADDRINT addr)
    {
        return truncate(hash(addr, m_ghr->fold(TAG_SIZE)), TAG_SIZE);
    }
    // Only for TAGE: current GHR value.
    UINT128 get_ghr()
    {
        return m_ghr->getVal();
    }
    // Only for TAGE: reset the counter selected by addr to its initial value
    // (weakly taken).
    void reset_ctr(ADDRINT addr)
    {
        counter(addr).reset();
    }
    // Only for TAGE: shift an outcome into the GHR without touching the PHT.
    void shift_history(bool taken)
    {
        m_ghr->shiftIn(taken);
    }
    // Only for TAGE: train the selected PHT counter without shifting the GHR.
    void update_counter(bool takenActually, ADDRINT addr)
    {
        SaturatingCnt& scnt = counter(addr);
        // From a weak state (1 or 2), a misprediction jumps straight to the
        // opposite strong state (1 -> 3, 2 -> 0).
        // NOTE(review): these literals are the weak states only for 2-bit counters.
        if (takenActually) {
            if (scnt.getVal() == 1) scnt.increase();
            scnt.increase();
        } else {
            if (scnt.getVal() == 2) scnt.decrease();
            scnt.decrease();
        }
    }
    bool predict(ADDRINT addr)
    {
        return counter(addr).isTaken();
    }
    // Train the selected counter, then shift the real outcome into the GHR
    // (takenPredicted is unused).
    void update(bool takenActually, bool takenPredicted, ADDRINT addr)
    {
        update_counter(takenActually, addr);
        shift_history(takenActually);
    }
};
/* ===================================================================== */
/* Tournament predictor: Select output by global/local selection history */
/* ===================================================================== */
// Chooses between two sub-predictors with a saturating selector counter:
// selector in its taken half -> use m_BPs[1], otherwise m_BPs[0].
// Takes ownership of both sub-predictors. predict() must be called before
// update() for each branch, since update() uses the stored predictions.
class TournamentPredictor: public BranchPredictor
{
    BranchPredictor* m_BPs[2];   // Sub-predictors, owned
    SaturatingCnt* m_gshr;       // Selector counter, owned
    bool prediction[2];          // Latest prediction of each sub-predictor
    bool predictCorrect[2];      // Whether each sub-predictor was right

    // Owns raw pointers: copying is disabled.
    TournamentPredictor(const TournamentPredictor&);
    TournamentPredictor& operator=(const TournamentPredictor&);
public:
    TournamentPredictor(BranchPredictor* BP0, BranchPredictor* BP1, size_t gshr_width = 2)
    {
        m_BPs[0] = BP0;
        m_BPs[1] = BP1;
        m_gshr = new SaturatingCnt(gshr_width);
        prediction[0] = prediction[1] = false;
        predictCorrect[0] = predictCorrect[1] = false;
    }
    ~TournamentPredictor()
    {
        delete m_BPs[0];
        delete m_BPs[1];
        delete m_gshr;
    }
    bool predict(ADDRINT addr)
    {
        for (size_t i = 0; i < 2; i++)
            prediction[i] = m_BPs[i]->predict(addr);
        return prediction[m_gshr->isTaken()];
    }
    void update(bool takenActually, bool takenPredicted, ADDRINT addr)
    {
        for (size_t i = 0; i < 2; i++) {
            predictCorrect[i] = (prediction[i] == takenActually);
            // Sub-predictors expect the direction they predicted, not whether it was right.
            m_BPs[i]->update(takenActually, prediction[i], addr);
        }
        // Move the selector toward whichever sub-predictor alone was right,
        // with the same weak-state jump as the 2-bit counters above.
        if (predictCorrect[1] && !predictCorrect[0]) {
            if (m_gshr->getVal() == 1) m_gshr->increase();
            m_gshr->increase();
        } else if (predictCorrect[0] && !predictCorrect[1]) {
            if (m_gshr->getVal() == 2) m_gshr->decrease();
            m_gshr->decrease();
        }
    }
};
/* ===================================================================== */
/* TArget GEometric history length Predictor */
/* ===================================================================== */
// T0 is an untagged BHT (m_bht); T1..T(m_tnum-1) are tagged global-history
// tables whose GHR widths grow geometrically. Slot 0 of m_T / m_useful /
// m_tags is unused and set to NULL. predict() must be called before update()
// for each branch, since update() uses the provider chosen by predict().
// hash2 is currently unused.
template<UINT128 (*hash1)(UINT128 pc, UINT128 ghr), UINT128 (*hash2)(UINT128 pc, UINT128 ghr)>
class TAGEPredictor: public BranchPredictor
{
    const size_t m_tnum;                  // number of sub-predictors (T[0 : m_tnum - 1])
    const size_t m_entries_log;           // log2 of PHT rows of T[1 : m_tnum - 1]
    BHTPredictor* m_bht;                  // T0, owned
    GlobalHistoryPredictor<hash1>** m_T;  // T1..T(m_tnum-1), owned; m_T[0] == NULL
    bool* m_T_pred;                       // per-table predictions from the last predict()
    UINT8** m_useful;                     // usefulness counters (0..3); row 0 == NULL
    UINT8** m_tags;                       // partial tags; row 0 == NULL
    int provider_indx;                    // provider's index (0 = BHT)
    int altpred_indx;                     // alternate provider's index (0 = BHT)
    const size_t m_rst_period;            // reset period of usefulness
    size_t m_rst_cnt;                     // reset counter

    // Owns raw arrays: copying is disabled.
    TAGEPredictor(const TAGEPredictor&);
    TAGEPredictor& operator=(const TAGEPredictor&);

    // Number of rows in each tagged table.
    size_t entries() const { return (size_t)1 << m_entries_log; }

    // After a misprediction, claim an entry with zero usefulness in a table with
    // longer history than the provider; if every candidate is useful, age them.
    void allocate(ADDRINT addr)
    {
        for (size_t i = provider_indx + 1; i < m_tnum; i++) {
            const UINT128 idx = get_index(i, addr);
            if (m_useful[i][idx] == 0) {
                m_T[i]->reset_ctr(addr);
                m_tags[i][idx] = (UINT8)m_T[i]->get_tag(addr);
                return;
            }
        }
        for (size_t i = provider_indx + 1; i < m_tnum; i++) {
            UINT8& u = m_useful[i][get_index(i, addr)];
            if (u > 0) u--;
        }
    }
public:
    // param: tnum: number of sub-predictors (>= 1)
    //        T0_entry_num_log: log2 of BHT rows of T0
    //        T1ghr_len: GHR width of T1
    //        alpha: geometric ratio of GHR widths of T[1 : m_tnum - 1]
    //        Tn_entry_num_log: log2 of PHT rows of T[1 : m_tnum - 1]
    //        scnt_width: counter width of T[1 : m_tnum - 1] (3 by default)
    //        rst_period: reset period of usefulness
    TAGEPredictor(size_t tnum, size_t T0_entry_num_log, size_t T1ghr_len, float alpha, size_t Tn_entry_num_log, size_t scnt_width = 3, size_t rst_period = 256*1024)
        : m_tnum(tnum), m_entries_log(Tn_entry_num_log), provider_indx(0), altpred_indx(0),
          m_rst_period(rst_period), m_rst_cnt(0)
    {
        assert(tnum >= 1);
        m_T = new GlobalHistoryPredictor<hash1>* [m_tnum];
        m_T_pred = new bool [m_tnum];
        m_useful = new UINT8* [m_tnum];
        m_tags = new UINT8* [m_tnum];
        // Slot 0 belongs to m_bht; NULL lets the destructor treat all slots uniformly.
        m_T[0] = NULL;
        m_useful[0] = NULL;
        m_tags[0] = NULL;
        for (size_t i = 0; i < m_tnum; i++)
            m_T_pred[i] = false;
        m_bht = new BHTPredictor(T0_entry_num_log);
        const size_t n = entries();
        size_t ghr_size = T1ghr_len;
        for (size_t i = 1; i < m_tnum; i++)
        {
            m_T[i] = new GlobalHistoryPredictor<hash1>(ghr_size, m_entries_log, scnt_width);
            ghr_size = (size_t)(ghr_size * alpha);
            m_useful[i] = new UINT8 [n];
            memset(m_useful[i], 0, sizeof(UINT8) * n);
            m_tags[i] = new UINT8 [n];
            memset(m_tags[i], 0, sizeof(UINT8) * n);
        }
    }
    ~TAGEPredictor()
    {
        for (size_t i = 1; i < m_tnum; i++) {
            delete m_T[i];
            delete[] m_useful[i];
            delete[] m_tags[i];
        }
        delete m_bht;
        delete[] m_T;
        delete[] m_T_pred;
        delete[] m_useful;
        delete[] m_tags;
    }
    // Row of table i (>= 1) for addr; must match GlobalHistoryPredictor's own indexing.
    UINT128 get_index(size_t i, UINT128 addr) {
        return truncate(hash1(addr, m_T[i]->get_ghr()), m_entries_log);
    }
    // Provider = longest-history table whose tag matches (T0 if none);
    // alternate = the next-longest match (T0 if none).
    bool predict(ADDRINT addr)
    {
        provider_indx = 0;
        altpred_indx = 0;
        m_T_pred[0] = m_bht->predict(addr);
        for (size_t i = 1; i < m_tnum; i++) {
            if (m_tags[i][get_index(i, addr)] != m_T[i]->get_tag(addr))
                continue;
            m_T_pred[i] = m_T[i]->predict(addr);
            altpred_indx = provider_indx;
            provider_indx = (int)i;
        }
        return m_T_pred[provider_indx];
    }
    // takenPredicted is not used: correctness is derived from the provider's
    // stored prediction, which is what predict() returned.
    void update(bool takenActually, bool takenPredicted, ADDRINT addr)
    {
        const bool providerPred = m_T_pred[provider_indx];
        const bool providerCorrect = (providerPred == takenActually);

        // 1. Train the provider's counter (histories are shifted at the end,
        //    so all indices below use the same history as predict()).
        if (provider_indx == 0)
            m_bht->update(takenActually, providerPred, addr);
        else
            m_T[provider_indx]->update_counter(takenActually, addr);

        // 2. Usefulness changes only when provider and alternate disagree.
        if (provider_indx > 0 && providerPred != m_T_pred[altpred_indx]) {
            UINT8& u = m_useful[provider_indx][get_index(provider_indx, addr)];
            if (providerCorrect) {
                if (u < 3) u++;
            } else {
                if (u > 0) u--;
            }
        }

        // 3. Periodically clear all usefulness counters.
        if (++m_rst_cnt >= m_rst_period) {
            m_rst_cnt = 0;
            for (size_t i = 1; i < m_tnum; i++)
                memset(m_useful[i], 0, sizeof(UINT8) * entries());
        }

        // 4. Allocate a longer-history entry on a misprediction.
        if (!providerCorrect)
            allocate(addr);

        // 5. Every tagged table must see every outcome in its global history.
        for (size_t i = 1; i < m_tnum; i++)
            m_T[i]->shift_history(takenActually);
    }
};
// This function is called every time a control-flow instruction is encountered
void predictBranch(ADDRINT pc, BOOL direction)
{
    // Runs (via the calls inserted by Instruction) each time an instrumented
    // control-flow edge executes: predict, train, then tally the outcome.
    const BOOL guess = BP->predict(pc);
    BP->update(direction, guess, pc);
    if (guess && direction)
        takenCorrect++;
    else if (guess)
        takenIncorrect++;
    else if (direction)
        notTakenIncorrect++;
    else
        notTakenCorrect++;
}
// Pin calls this function every time a new instruction is encountered
void Instruction(INS ins, void * v)
{
    // Only control-flow instructions that also have a fall-through path are
    // instrumented.
    if (!INS_IsControlFlow(ins) || !INS_HasFallThrough(ins))
        return;
    // Taken edge: report direction = TRUE.
    INS_InsertCall(ins, IPOINT_TAKEN_BRANCH, (AFUNPTR)predictBranch,
                   IARG_INST_PTR, IARG_BOOL, TRUE, IARG_END);
    // Fall-through edge: report direction = FALSE.
    INS_InsertCall(ins, IPOINT_AFTER, (AFUNPTR)predictBranch,
                   IARG_INST_PTR, IARG_BOOL, FALSE, IARG_END);
}
// This knob sets the output file name
KNOB<string> KnobOutputFile(KNOB_MODE_WRITEONCE, "pintool", "o", "brchPredict.txt", "specify the output file name");
// Called when the application exits: prints the outcome tallies and accuracy
// to stdout and to the output file, then destroys the predictor.
VOID Fini(int, VOID * v)
{
    const UINT64 correct = takenCorrect + notTakenCorrect;
    const UINT64 total = correct + takenIncorrect + notTakenIncorrect;
    // Avoid 0/0 (NaN) when no instrumented branch executed.
    const double precision = total ? 100.0 * double(correct) / double(total) : 0.0;
    cout << "takenCorrect: " << takenCorrect << endl
         << "takenIncorrect: " << takenIncorrect << endl
         << "notTakenCorrect: " << notTakenCorrect << endl
         << "notTakenIncorrect: " << notTakenIncorrect << endl
         << "Precision: " << precision << endl
         << "use0: " << use0 << endl
         << "use1: " << use1 << endl;
    OutFile.setf(ios::showbase);
    OutFile << "takenCorrect: " << takenCorrect << endl
            << "takenIncorrect: " << takenIncorrect << endl
            << "notTakenCorrect: " << notTakenCorrect << endl
            << "notTakenIncorrect: " << notTakenIncorrect << endl
            << "Precision: " << precision << endl;
    OutFile.close();
    delete BP;
    BP = NULL;
}
/* ===================================================================== */
/* Print Help Message */
/* ===================================================================== */
// Print the tool's help text and knob summary; returns -1 so main() can exit with it.
INT32 Usage()
{
    cerr << "This tool simulates a branch predictor on the application's control-flow instructions and reports its accuracy" << endl;
    cerr << endl << KNOB_BASE::StringKnobSummary() << endl;
    return -1;
}
/* ===================================================================== */
/* Main */
/* ===================================================================== */
/* argc, argv are the entire command line: pin -t <toolname> -- ... */
/* ===================================================================== */
int main(int argc, char * argv[])
{
// TODO: New your Predictor below.
// BP = new BHTPredictor(15);
// BP = new GlobalHistoryPredictor<f_xor>(10, 15);
// BP = new TournamentPredictor(new GlobalHistoryPredictor<f_xor>(9, 14), new BHTPredictor(14));
BP = new TAGEPredictor<f_xor, f_xor1>(7, 11, 10, 1.5, 12);
// Initialize pin
if (PIN_Init(argc, argv)) return Usage();
OutFile.open(KnobOutputFile.Value().c_str());
// Register Instruction to be called to instrument instructions
INS_AddInstrumentFunction(Instruction, 0);
// Register Fini to be called when the application exits
PIN_AddFiniFunction(Fini, 0);
// Start the program, never returns
PIN_StartProgram();
return 0;
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/chen-boju/arch-lab.git
[email protected]:chen-boju/arch-lab.git
chen-boju
arch-lab
archLab
master

搜索帮助