class VInRegen : public RegenType {
public:
VInRegen(uint64_t task_num, uint subpartition) {
task_num_ = task_num; subparition_ = subpartition;
}
CDErrType Regenerate(void* data_ptr, uint64_t len) {
}
protected:
uint64_t task_num_;
uint subpartition_;
};
void SpMVRecurse(const SparseMatrix* matrix,
const HierVector* v_in,
HierVector* v_out,
const CDHandle* current_cd,
uint num_tasks
) {
if (num_tasks > RECURSE_DEGREE) {
uint tasks_per_child = num_tasks/RECURSE_DEGREE;
for (int child=0; child < RECURSE_DEGREE; child++) {
CDHandle* child_cd;
child_cd = current_cd->CreateAndBegin(child, tasks_per_child);
CDEvent preserve_event;
child_cd->Preserve(preserve_event,
matrix->Subpartition(child),
matrix->SubpartitionLen(child),
"Matrix",
"Matrix", matrix->PartitionOffset(),
0,
);
VInRegen v_in_regen(ParRuntime::MyTaskNum(), child);
child_cd->Preserve(preserve_event,
v_in->Subpartition(child),
v_in->SubpartitionLen(child),
"vIn",
"vIn", v_in->PartitionOffset(),
&v_in_regen
);
SpMVRecurse(matrix->Subpartition(child),
v_in->Subpartition(child),
v_out->Subpartition(child),
child_cd, tasks_per_child);
preserve_event->Wait();
child_cd->Complete();
child_cd->Destroy();
}
}
else {
for (int child=0; child < num_tasks; child++) {
CDHandle* child_cd;
CDHandle* child_cd = current_cd->Create();
child_cd->Begin();
CDEvent preserve_event;
child_cd->Preserve(preserve_event,
matrix->Partition(),
matrix->PartitionLen(),
"Matrix",
"Matrix", matrix->PartitionOffset(),
0,
);
VInRegen v_in_regen(ParRuntime::MyTaskNum(), child);
child_cd->Preserve(preserve_event,
v_in->Partition(),
v_in->SubpartitionLen(),
"vIn",
"vIn", v_in->PartitionOffset(),
&v_in_regen
);
SpMVLeaf(matrix->Subpartition(child),
v_in->Subpartition(child),
v_out->Subpartition(child),
child_cd, tasks_per_child);
child_cd->Complete();
child_cd->Destroy();
}
}
v_out->ReduceSubpartitions(num_tasks);
}
void SpMVLeaf(const SparseMatrix* matrix,
const HierVector* v_in,
HierVector* v_out,
const CDHandle* current_cd,
uint num_tasks
) {
for (uint row=0; row < matrix->NumRows(); row++) {
v_out[row] = 0.0;
for (unit col = matrix->RowStart[row];
col < matrix->RowStart[row+1];
col++) {
uint prev_idx = 0;
uint idx = matrix->Index[col];
v_out[row] += matrix->NonZero[col]*v_in[idx];
CDAssert(idx >= prev_idx);
prev_idx = idx;
}
}
}