Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 118 additions & 41 deletions streamcompiler/src/compiler/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ use core::panic;
use inkwell::{builder::Builder, context::Context, execution_engine::{ExecutionEngine, JitFunction}, module::{Linkage, Module}, passes::PassBuilderOptions, targets::{CodeModel, FileType, InitializationConfig, RelocMode, Target, TargetMachine}, types::{BasicType, BasicTypeEnum}, values::{BasicMetadataValueEnum, FloatValue, FunctionValue, IntValue, VectorValue}, AddressSpace, OptimizationLevel};
use inkwell::attributes::AttributeLoc;
use inkwell::llvm_sys::core::LLVMGetEnumAttributeKind;
use inkwell::types::AnyType;
use inkwell::types::{AnyType, VectorType};
use inkwell::values::AsValueRef;
use crate::{compiler::expression::{ExprCompiler, ScalarExprCompiler, VectorExprCompiler}, parser::{Clause, ClauseType, Expr}};

pub type StreamCompilerProgramType = unsafe extern "C" fn(*const f64, i32) -> ();
pub type NumgrepProgramType = unsafe extern "C" fn(*const f64, *const bool, i32) -> ();
pub type NumgrepProgramType = unsafe extern "C" fn(*const f64, *const u8, i32) -> ();

static mut ID: u64 = 0;

Expand Down Expand Up @@ -48,7 +48,7 @@ impl JittedStreamCompilerProgram<'_> {
}

impl JittedNumgrepProgram<'_> {
pub unsafe fn call(&self, input: *const f64, filter: *const bool, len: i32) {
pub unsafe fn call(&self, input: *const f64, filter: *const u8, len: i32) {
self.function.call(input, filter, len);
}
}
Expand Down Expand Up @@ -119,14 +119,14 @@ impl<'ctx> CodeGen<'ctx> {

pub fn compile_numgrep(&'ctx self, program: &'ctx[Clause]) -> JittedNumgrepProgram<'ctx>
{
let compiled_program = self.compile_numgrep_program::<1>(program);
JittedNumgrepProgram { vector_width: Some(1), function: compiled_program }
let compiled_program = self.compile_numgrep_program::<VEC_WIDTH>(program);
JittedNumgrepProgram { vector_width: Some(VEC_WIDTH), function: compiled_program }
}

fn compile_stream_compiler_program<const VEC_WIDTH: u32>(&'ctx self, program: &'ctx[Clause]) -> JitFunction<'ctx, StreamCompilerProgramType>
{
if VEC_WIDTH != 1 && VEC_WIDTH != 4 && VEC_WIDTH != 8 {
panic!("Unsupported vector width: {}. Only 1 (scalar), 4 and 8 are supported.", VEC_WIDTH);
if VEC_WIDTH != 1 && VEC_WIDTH != 8 {
panic!("Unsupported vector width: {}. Only 1 (scalar), and 8 are supported.", VEC_WIDTH);
}

struct CompiledClause<'a> {
Expand Down Expand Up @@ -327,8 +327,8 @@ impl<'ctx> CodeGen<'ctx> {
}

fn compile_numgrep_program<const VEC_WIDTH: u32>(&'ctx self, program: &'ctx[Clause]) -> JitFunction<'ctx, NumgrepProgramType> {
if VEC_WIDTH != 1 {
panic!("Unsupported vector width: {}. Only 1 (scalar) is supported.", VEC_WIDTH);
if VEC_WIDTH != 1 && VEC_WIDTH != 4 && VEC_WIDTH != 8 {
panic!("Unsupported vector width: {}. Only 1 (scalar), 4, or 8 is supported.", VEC_WIDTH);
}

struct CompiledClause<'a> {
Expand Down Expand Up @@ -361,12 +361,18 @@ impl<'ctx> CodeGen<'ctx> {
false
);

let (fn_name, function) = if VEC_WIDTH == 1 {
let (input_pointee_type, bool_type, fn_name, function) = if VEC_WIDTH == 1 {
let input_pointee_type = self.context.f64_type().as_basic_type_enum();
let bool_type = self.context.bool_type().as_basic_type_enum();
let fn_name = format!("numgrep_program_scalar_{}", get_id());
let function = self.module.add_function(&fn_name, fn_type, Some(Linkage::External));
(fn_name, function)
(input_pointee_type, bool_type, fn_name, function)
} else {
panic!("numgrep program compilation for vector width > 1 is not implemented yet");
let input_pointee_type = self.context.f64_type().vec_type(VEC_WIDTH).as_basic_type_enum();
let bool_type = self.context.bool_type().vec_type(VEC_WIDTH).as_basic_type_enum();
let fn_name = format!("numgrep_program_vec{}_{}", VEC_WIDTH, get_id());
let function = self.module.add_function(&fn_name, fn_type, Some(Linkage::External));
(input_pointee_type, bool_type, fn_name, function)
};

let entry = self.context.append_basic_block(function, "entry");
Expand All @@ -381,7 +387,12 @@ impl<'ctx> CodeGen<'ctx> {
let loop_end_bb = self.context.append_basic_block(function, "loop_end");
let exit_bb = self.context.append_basic_block(function, "exit");

let should_include = self.builder.build_alloca(self.context.bool_type(), "should_filter").expect("Failed to allocate should_filter variable");
let should_include = if VEC_WIDTH == 1 {
self.builder.build_alloca(self.context.bool_type(), "should_filter").expect("Failed to allocate should_filter variable")
} else {
self.builder.build_alloca(self.context.bool_type().vec_type(VEC_WIDTH), "should_filter").expect("Failed to allocate should_filter variable")
};

self.builder.build_unconditional_branch(loop_start_bb).expect("Failed to build unconditional branch to loop");
self.builder.position_at_end(loop_start_bb);

Expand All @@ -390,7 +401,7 @@ impl<'ctx> CodeGen<'ctx> {
(&self.context.i32_type().const_zero(), entry),
(&self.builder.build_int_add(
loop_index.as_basic_value().into_int_value(),
self.context.i32_type().const_int(VEC_WIDTH as u64, false),
self.context.i32_type().const_int(1, false),
"next_index"
).expect("Could not build increment"), loop_end_bb),
]);
Expand All @@ -399,7 +410,11 @@ impl<'ctx> CodeGen<'ctx> {
self.builder.build_int_compare(
inkwell::IntPredicate::ULT,
loop_index.as_basic_value().into_int_value(),
input_len,
self.builder.build_int_unsigned_div(
input_len,
self.context.i32_type().const_int(VEC_WIDTH as u64, false),
"loop_condition_div"
).expect("Could not build loop condition division"),
"loop_condition"
).expect("Could not build loop condition"),
loop_body_bb,
Expand All @@ -409,10 +424,10 @@ impl<'ctx> CodeGen<'ctx> {
self.builder.position_at_end(loop_body_bb);

let next_input = self.builder.build_load(
self.context.f64_type(),
input_pointee_type,
unsafe {
self.builder.build_gep(
self.context.f64_type(),
input_pointee_type,
input_ptr,
&[loop_index.as_basic_value().into_int_value()],
"input_ptr"
Expand All @@ -421,7 +436,21 @@ impl<'ctx> CodeGen<'ctx> {
"next_input"
).expect("Failed to load next input");

self.builder.build_store(should_include, self.context.bool_type().const_all_ones()).expect("Failed to store should_filter variable");
match VEC_WIDTH {
1 => {
self.builder.build_store(should_include, self.context.bool_type().const_all_ones()).expect("Failed to store should_include");
},
_ => {
// Yes I hate this too
let values = (0..VEC_WIDTH)
.map(|_| self.context.bool_type().const_all_ones())
.collect::<Vec<_>>();
self.builder.build_store(
should_include,
VectorType::const_vector(&values)
).expect("Failed to store should_include");
},
}

for clause in compiled_clauses {
let clause_entry_bb = self.context.append_basic_block(function, "clause_entry");
Expand All @@ -437,30 +466,64 @@ impl<'ctx> CodeGen<'ctx> {
.expect_left("Could not get result of clause call");

// This looks wasteful, but the hope is that a) we'll have only 1 clause, and b) it will be easier to vectorize
let current_should_include= self.builder.build_load(self.context.bool_type(), should_include, "current_should_include").expect("Failed to load should_include variable");
let new_should_include= self.builder.build_and(
current_should_include.into_int_value(),
result.into_int_value(),
"new_should_exclude"
).expect("Failed to build AND for should_filter");

self.builder.build_store(should_include, new_should_include).expect("Failed to store clause result");
let current_should_include= self.builder.build_load(bool_type, should_include, "current_should_include").expect("Failed to load should_include variable");

if VEC_WIDTH == 1 {
let new_should_include= self.builder.build_and(
current_should_include.into_int_value(),
result.into_int_value(),
"new_should_exclude"
).expect("Failed to build AND for should_filter");

self.builder.build_store(should_include, new_should_include).expect("Failed to store clause result");
} else {
let new_should_include= self.builder.build_and(
current_should_include.into_vector_value(),
result.into_vector_value(),
"new_should_exclude"
).expect("Failed to build AND for should_filter");

self.builder.build_store(should_include, new_should_include).expect("Failed to store clause result");
}
}

self.builder.build_unconditional_branch(loop_end_bb).expect("Failed to build unconditional branch to loop end");
self.builder.position_at_end(loop_end_bb);

self.builder.build_store(
unsafe {
self.builder.build_gep(
self.context.bool_type(),
filter_ptr,
&[loop_index.as_basic_value().into_int_value()],
"filter_ptr"
).expect("Could not build GEP for filter")
},
self.builder.build_load(self.context.bool_type(), should_include, "final_should_include").expect("Failed to load should_include variable")
).expect("Failed to store final filter result");
if VEC_WIDTH == 1 {
self.builder.build_store(
unsafe {
self.builder.build_gep(
bool_type,
filter_ptr,
&[loop_index.as_basic_value().into_int_value()],
"filter_ptr"
).expect("Could not build GEP for filter")
},
self.builder.build_load(bool_type, should_include, "final_should_include").expect("Failed to load should_include variable")
).expect("Failed to store final filter result");
} else {
let final_should_include = self.builder.build_load(
bool_type,
should_include,
"final_should_include"
).expect("Failed to load should_include variable");

// note that we're storing an <VEC_WIDTH x i1> vector into the filter pointer, which means the bits are packed
// i.e. a <8 x i1> is a single byte
self.builder.build_store(
unsafe {
self.builder.build_gep(
bool_type,
filter_ptr,
&[loop_index.as_basic_value().into_int_value()],
"filter_ptr"
).expect("Could not build GEP for filter")
},
self.builder.build_load(bool_type, should_include, "final_should_include").expect("Failed to load should_include variable")
).expect("Failed to store final filter result");
}

self.builder.build_unconditional_branch(loop_start_bb).expect("Failed to build unconditional branch to loop start");

self.builder.position_at_end(exit_bb);
Expand All @@ -478,7 +541,7 @@ impl<'ctx> CodeGen<'ctx> {
];

self.module.run_passes(passes.join(",").as_str(), &self.get_machine(), PassBuilderOptions::create()).expect("Failed to run passes on module");
// self.dump_module();
self.dump_module();
unsafe { self.execution_engine.get_function(&fn_name).ok().unwrap() }
}

Expand Down Expand Up @@ -514,8 +577,22 @@ impl<'ctx> CodeGen<'ctx> {
match clause.clause_type {
ClauseType::Filter => {
match x.get_type() {
BasicTypeEnum::VectorType(_) => {
panic!("Vec filtering is not supported yet, please use scalar filtering instead");
BasicTypeEnum::VectorType(t) => {
if t.get_element_type().into_float_type() != self.context.f64_type() {
panic!("Expected vector of f64 type for input parameter, got {:?}", t);
}
if t.get_size() != VEC_WIDTH as u32 {
panic!("Expected vector of f64 type with width {}, got {:?}", VEC_WIDTH, t);
}

let condition = self.compile_expression_vec(&clause.expression, x.into_vector_value()).expect("Failed to compile expression");

let condition_as_bool = self.float_as_bool_vec(condition);
if filters_return_i1 {
self.builder.build_return(Some(&condition_as_bool)).expect("Failed to build return for filter clause");
} else {
panic!("Vec filtering is only supported for i1 return type, please use scalar filtering instead");
}
},
BasicTypeEnum::FloatType(t) => {
if t != self.context.f64_type() {
Expand Down Expand Up @@ -592,7 +669,7 @@ impl<'ctx> CodeGen<'ctx> {
self.expr_compiler.float_as_bool(self, value.into()).into()
}

fn float_as_bool_vec<'cg>(&'cg self, value: VectorValue<'cg>) -> IntValue<'cg> {
fn float_as_bool_vec<'cg>(&'cg self, value: VectorValue<'cg>) -> VectorValue<'cg> {
self.vector_expr_compiler.float_as_bool(self, value.into()).into()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ impl<const VEC_WIDTH: u32> VectorExprCompiler<VEC_WIDTH> {
VectorType::const_vector(&values)
}

fn const_vec<'a>(self: &'a Self, context: &'a Context, value: f64) -> VectorValue<'a> {
pub fn const_vec<'a>(self: &'a Self, context: &'a Context, value: f64) -> VectorValue<'a> {
let as_f64 = context.f64_type().const_float(value);
self.fill_vec(as_f64)
}
Expand Down
22 changes: 20 additions & 2 deletions streamcompiler/src/numgrep/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ impl<'a> Runner<'a> {

let jitted_program = codegen.compile_numgrep(program);

match jitted_program.vector_width {
Some(8) => (),
Some(_) => panic!("Vector width other than 8 is not supported for numgrep"),
None => (),
}

Runner {
jitted_program,
}
Expand All @@ -25,13 +31,25 @@ impl<'a> Runner<'a> {
.flat_map(|(_, floats)| floats.iter().cloned())
.collect::<Vec<f64>>();

let should_include = vec![false; input_floats.len()];
// TODO: Handle when input_floats is not a multiple of the vector width
let should_include = vec![0u8; input_floats.len() / self.jitted_program.vector_width.unwrap_or(1) as usize];
unsafe { self.jitted_program.call(input_floats.as_ptr(), should_include.as_ptr(), input_floats.len() as i32); }

println!("should_include: {:?}", should_include);

/// Reports whether the filter flag for element `index` is set in `should_include`.
///
/// When `vec_width` is 1, every element owns a whole byte and any non-zero byte counts as set.
/// For any other width the flags are bit-packed: byte `index / vec_width` holds the flags of
/// `vec_width` consecutive elements, starting at the least-significant bit.
///
/// Panics if the computed byte position is outside `should_include`.
#[inline]
fn is_set(should_include: &[u8], index: usize, vec_width: usize) -> bool {
    // Scalar layout: one byte per element.
    if vec_width == 1 {
        return should_include[index] != 0;
    }

    // Packed layout: pick the byte for this group, then test the element's bit inside it.
    let packed = should_include[index / vec_width];
    let mask = 1u8 << (index % vec_width);
    packed & mask != 0
}

let mut output_index = 0;
for i in 0..input.len() {
for j in 0..input[i].1.len() {
if should_include[output_index + j] {
if is_set(&should_include, output_index + j, self.jitted_program.vector_width.unwrap_or(1) as usize) {
println!("{}", input[i].0);
break;
}
Expand Down