Skip to content

Commit

Permalink
arch-riscv: Add more constraints
Browse files Browse the repository at this point in the history
Change-Id: If31243ff7e2ce52c96a012153c3de97be83373c4
  • Loading branch information
FanYang98 committed Dec 26, 2022
1 parent 6237ec8 commit 3c5e649
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 9 deletions.
102 changes: 100 additions & 2 deletions src/arch/riscv/isa/formats/vector_arith.isa
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ let {{
}
'''

def VI_CHECK_ST_INDEX(elt):
def VI_CHECK_LD_INDEX(elt):
return '''
const uint64_t nf = machInst.nf + 1;
const int NVPR = 32;
Expand Down Expand Up @@ -587,8 +587,106 @@ let {{
}
''' %(elt)

def VI_LD():
def VI_CHECK_ST_INDEX(elt):
return '''
const uint64_t nf = machInst.nf + 1;
const int NVPR = 32;
const uint64_t elt = %s;
if(elt>gem5::RiscvISA::ELEN){
std::string error =
csprintf("Invalid elt and ELEN");
return std::make_shared<IllegalInstFault>(error, machInst);
}
const float vflmul = Vflmul(vtype_vlmul(machInst.vtype8));
const size_t vsew = sizeof(vu) * 8;
const float vemul = ((float)elt / vsew * vflmul);
if(vemul < 0.125 || vemul > 8){
std::string error =
csprintf("Invalid LMUL and SEW");
return std::make_shared<IllegalInstFault>(error, machInst);
}
if(!(is_aligned(machInst.vd, vflmul)&&
is_aligned(machInst.vs2, vemul))) {
std::string error =
csprintf("Unaligned Vd, Vs2 or Vs1 group");
return std::make_shared<IllegalInstFault>(error, machInst);
}
if(!(is_aligned(machInst.vd, vflmul)&&
is_aligned(machInst.vs2, vemul))) {
std::string error =
csprintf("Unaligned Vd, Vs2 or Vs1 group");
return std::make_shared<IllegalInstFault>(error, machInst);
}
const uint64_t flmul = vflmul < 1 ? 1 : vflmul;
if(!((nf*flmul)<=(NVPR/4) && (machInst.vd+nf*flmul)<=NVPR)) {
std::string error =
csprintf("Unaligned Vd group");
return std::make_shared<IllegalInstFault>(error, machInst);
}
''' %(elt)


def VI_CHECK_LOAD(elt, is_mask):
return '''
const uint64_t nf = machInst.nf + 1;
const float vflmul = Vflmul(vtype_vlmul(machInst.vtype8));
const size_t vsew = vtype_SEW(machInst.vtype8);
const int NVPR = 32;
bool is_mask = %s;
uint64_t elt = %s;
uint64_t veew = is_mask ? 1 : elt;
float vemul = is_mask ? 1 : ((float)veew / vsew * vflmul);
uint64_t emul = vemul < 1 ? 1 : vemul;
if(vemul < 0.125 || vemul > 8){
std::string error =
csprintf("Invalid LMUL and SEW");
return std::make_shared<IllegalInstFault>(error, machInst);
}
if(!(is_aligned(machInst.vd, vemul))) {
std::string error =
csprintf("Unaligned Vd group");
return std::make_shared<IllegalInstFault>(error, machInst);
}
if(!((nf*emul)<=(NVPR/4) && (machInst.vd+nf*emul)<=NVPR)) {
std::string error =
csprintf("Invalid Vd group");
return std::make_shared<IllegalInstFault>(error, machInst);
}
if(veew > gem5::RiscvISA::ELEN){
std::string error =
csprintf("Invalid veew and ELEN");
return std::make_shared<IllegalInstFault>(error, machInst);
}
''' %(is_mask, elt)

def VI_CHECK_STORE(elt, is_mask):
return VI_CHECK_LOAD(elt, is_mask)

def VI_CHECK_LD_WHOLE(elt):
return '''
uint64_t elt = %s;
const uint64_t nf = machInst.nf;
if(elt>gem5::RiscvISA::ELEN){
std::string error =
csprintf("Invalid elt and ELEN");
return std::make_shared<IllegalInstFault>(error, machInst);
}
if(!(is_aligned(machInst.vd, nf+1))) {
std::string error =
csprintf("Unaligned Vd group");
return std::make_shared<IllegalInstFault>(error, machInst);
}
''' %(elt)

def VI_CHECK_ST_WHOLE():
return '''
const uint64_t nf = machInst.nf;
if(!(is_aligned(machInst.vd, nf+1))) {
std::string error =
csprintf("Unaligned Vd group");
return std::make_shared<IllegalInstFault>(error, machInst);
}
'''

}};
Expand Down
43 changes: 39 additions & 4 deletions src/arch/riscv/isa/formats/vector_mem.isa
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,36 @@ def VMemBase(name, Name, ea_code, memacc_code, mem_flags,
mem_flags = makeList(mem_flags)
inst_flags = makeList(inst_flags)
inst_name, inst_suffix = name.split("_", maxsplit=1)
index_flag = ""
if inst_name.find("xei") != -1:
index_flag = re.findall(r'\d+', inst_name)[0]
vi_ld_index_flag = ""
vi_st_index_flag = ""
vi_ld_flag = ""
vi_st_flag = ""
vi_ld_whole_flag = ""
vi_st_whole_flag = ""
vi_ld_stride_flag = ""
vi_st_stride_flag = ""
is_mask = "false"
if inst_name.find("xei") != -1 and inst_name.find("vl") != -1:
vi_ld_index_flag = re.findall(r'\d+', inst_name)[0]
if inst_name.find("xei") != -1 and inst_name.find("vs") != -1:
vi_st_index_flag = re.findall(r'\d+', inst_name)[0]
if inst_name.find("vle") != -1:
vi_ld_flag = re.findall(r'\d+', inst_name)[0]
if inst_name.find("vse") != -1:
vi_st_flag = re.findall(r'\d+', inst_name)[0]
if inst_name.find("vlm") != -1 or inst_name.find("vsm") != -1:
vi_ld_flag = 8
vi_st_flag = 8
is_mask = "true"
if inst_name.find("re") != -1 and len(re.findall(r'\d+', inst_name))==2:
vi_ld_whole_flag = re.findall(r'\d+', inst_name)[1]
if inst_name.find("vlse") != -1:
vi_ld_stride_flag = re.findall(r'\d+', inst_name)[0]
if inst_name.find("vsse") != -1:
vi_st_stride_flag = re.findall(r'\d+', inst_name)[0]
iop = InstObjParams(name, Name, base_class,
{'ea_code': ea_code,
'memacc_code': memacc_code,
Expand All @@ -68,7 +95,15 @@ def VMemBase(name, Name, ea_code, memacc_code, mem_flags,
{'ea_code': ea_code,
'memacc_code': memacc_code,
'postacc_code': postacc_code,
'vi_check_st_index': VI_CHECK_ST_INDEX(index_flag)},
'vi_check_ld_index': VI_CHECK_LD_INDEX(vi_ld_index_flag),
'vi_check_st_index': VI_CHECK_ST_INDEX(vi_st_index_flag),
'vi_check_ld': VI_CHECK_LOAD(vi_ld_flag, is_mask),
'vi_check_st': VI_CHECK_STORE(vi_st_flag, is_mask),
'vi_check_ld_whole': VI_CHECK_LD_WHOLE(vi_ld_whole_flag),
'vi_check_st_whole': VI_CHECK_ST_WHOLE(),
'vi_check_ld_stride': VI_CHECK_LOAD(vi_ld_stride_flag, is_mask),
'vi_check_st_stride': VI_CHECK_STORE(vi_st_stride_flag, is_mask),
},
inst_flags)
if mem_flags:
Expand Down
13 changes: 10 additions & 3 deletions src/arch/riscv/isa/templates/vector_mem.isa
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ Fault
%(op_decl)s;
%(op_rd)s;
%(ea_code)s;
%(vi_check_ld)s;

RiscvISA::vreg_t tmp_v0;
uint8_t *v0;
Expand Down Expand Up @@ -262,6 +263,7 @@ Fault
v0 = tmp_v0.as<uint8_t>();
}

%(vi_check_st)s;
%(op_decl)s;
%(op_rd)s;
%(ea_code)s;
Expand Down Expand Up @@ -454,6 +456,7 @@ Fault
%(class_name)s::execute(ExecContext *xc, Trace::InstRecord *traceData) const
{
Addr EA;
%(vi_check_st_whole)s;
%(op_decl)s;
%(op_rd)s;
%(ea_code)s;
Expand Down Expand Up @@ -563,10 +566,12 @@ Fault
%(class_name)s::execute(ExecContext *xc, Trace::InstRecord *traceData) const
{
Addr EA;
%(vi_check_ld_whole)s;
%(op_decl)s;
%(op_rd)s;
%(ea_code)s;


Fault fault = readMemAtomicLE(xc, traceData, EA,
*(vreg_t::Container*)(&Mem), memAccessFlags);
if (fault != NoFault)
Expand Down Expand Up @@ -704,6 +709,7 @@ Fault
Fault fault = NoFault;
Addr EA;

%(vi_check_ld_stride)s;
%(op_decl)s;
%(op_rd)s;
constexpr uint8_t elem_size = sizeof(Vd[0]);
Expand Down Expand Up @@ -890,7 +896,7 @@ Fault
{
Fault fault = NoFault;
Addr EA;

%(vi_check_st_stride)s;
%(op_decl)s;
%(op_rd)s;
constexpr uint8_t elem_size = sizeof(Vs3[0]);
Expand Down Expand Up @@ -1059,7 +1065,7 @@ Fault
Fault fault = NoFault;
Addr EA;

%(vi_check_st_index)s;
%(vi_check_ld_index)s;
%(op_decl)s;
%(op_rd)s;
%(ea_code)s;
Expand Down Expand Up @@ -1256,6 +1262,7 @@ Fault
Fault fault = NoFault;
Addr EA;

%(vi_check_st_index)s;
%(op_decl)s;
%(op_rd)s;
%(ea_code)s;
Expand Down Expand Up @@ -1347,4 +1354,4 @@ switch(machInst.vtype8.vsew) {
default: GEM5_UNREACHABLE;
}

}};
}};

0 comments on commit 3c5e649

Please sign in to comment.