Mercurial > hg > octave-thorsten
changeset 15016:005cb78e1dd1
Split pt-jit into multiple files.
* src/Makefile: Add jit-util.h, jit-typeinfo.h, jit-ir.h, jit-util.cc,
jit-typeinfo.cc, and jit-ir.cc.
* src/jit-ir.cc: New file.
* src/jit-ir.h: New file.
* src/jit-typeinfo.cc: New file.
* src/jit-typeinfo.h: New file.
* src/jit-util.h: New file.
* src/jit-util.cc: New file.
* src/pt-jit.cc: (jit_fail_exception): Move to jit-ir.h.
(fail): Removed function.
(jit_print, jit_use, jit_value, jit_instruction, jit_block, jit_phi_incomming,
jit_phi, jit_terminator, jit_call): Moved to jit-ir.cc.
(octave_jit_print_any, octave_jit_print_double, octave_jit_binary_any_any,
octave_jit_compute_nelem, octave_jit_release_any, octave_jit_release_matrix,
octave_jit_grab_any, octave_jit_grab_matrix, octave_jit_cast_any_matrix,
octave_jit_cast_matrix_any, octave_jit_cast_scalar_any,
octave_jit_cast_any_scalar, octave_jit_cast_complex_any,
octave_jit_cast_any_complex, octave_jit_gripe_nan_to_logical_conversion,
octave_jit_ginvalid_index, octave_jit_gindex_range,
octave_jit_paren_subsasgn_impl, octave_jit_paren_subsasgn_matrix_range,
octave_jit_complex_div, octave_jit_pow_scalar_scalar,
octave_jit_pow_complex_complex, octave_jit_pow_scalar_scalar,
octave_jit_pow_complex_scalar, octave_jit_pow_scalar_scalar,
octave_jit_pow_scalar_complex, octave_jit_pow_scalar_scalar,
octave_jit_print_matrix, octave_jit_call, jit_type, jit_function,
jit_operation, jit_typeinfo): Moved to jit-typeinfo.cc
* src/pt-jit.h (jit_print, jit_use, jit_value, jit_instruction, jit_block,
jit_phi_incomming, jit_phi, jit_terminator, jit_call): Moved to jit-ir.h.
(jit_internal_list, jit_internal_node, jit_range, jit_array): Moved to
jit-util.h.
(jit_type, jit_function, jit_operation, jit_typeinfo): Moved to jit-typeinfo.h
author | Max Brister <max@2bass.com> |
---|---|
date | Wed, 25 Jul 2012 21:12:47 -0500 |
parents | fee211d42c5c |
children | dd4ad69e4ab9 |
files | src/Makefile.am src/jit-ir.cc src/jit-ir.h src/jit-typeinfo.cc src/jit-typeinfo.h src/jit-util.cc src/jit-util.h src/pt-jit.cc src/pt-jit.h |
diffstat | 9 files changed, 4586 insertions(+), 4338 deletions(-) [+] |
line wrap: on
line diff
--- a/src/Makefile.am +++ b/src/Makefile.am @@ -220,7 +220,6 @@ pt-fcn-handle.h \ pt-id.h \ pt-idx.h \ - pt-jit.h \ pt-jump.h \ pt-loop.h \ pt-mat.h \ @@ -232,6 +231,12 @@ pt-walk.h \ pt.h +JIT_INCLUDES = \ + jit-util.h \ + jit-typeinfo.h \ + jit-ir.h \ + pt-jit.h + octinclude_HEADERS = \ Cell.h \ builtins.h \ @@ -310,7 +315,8 @@ zfstream.h \ $(OV_INCLUDES) \ $(OV_SPARSE_INCLUDES) \ - $(PT_INCLUDES) + $(PT_INCLUDES) \ + $(JIT_INCLUDES) nodist_octinclude_HEADERS = \ defaults.h \ @@ -393,7 +399,6 @@ pt-fcn-handle.cc \ pt-id.cc \ pt-idx.cc \ - pt-jit.cc \ pt-jump.cc \ pt-loop.cc \ pt-mat.cc \ @@ -404,6 +409,12 @@ pt-unop.cc \ pt.cc +JIT_SRC = \ + jit-util.cc \ + jit-typeinfo.cc \ + jit-ir.cc \ + pt-jit.cc + DIST_SRC = \ Cell.cc \ bitfcns.cc \ @@ -476,7 +487,8 @@ xpow.cc \ zfstream.cc \ $(OV_SRC) \ - $(PT_SRC) + $(PT_SRC) \ + $(JIT_SRC) include DLD-FUNCTIONS/module.mk
new file mode 100644 --- /dev/null +++ b/src/jit-ir.cc @@ -0,0 +1,601 @@ +/* + +Copyright (C) 2012 Max Brister <max@2bass.com> + +This file is part of Octave. + +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 3 of the License, or (at your +option) any later version. + +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, see +<http://www.gnu.org/licenses/>. + +*/ + +// defines required by llvm +#define __STDC_LIMIT_MACROS +#define __STDC_CONSTANT_MACROS + +#ifdef HAVE_CONFIG_H +#include <config.h> +#endif + +#ifdef HAVE_LLVM + +#include "jit-ir.h" + +#include <llvm/BasicBlock.h> +#include <llvm/Instructions.h> + +#include "error.h" +#include "pt-jit.h" + +// -------------------- jit_use -------------------- +jit_block * +jit_use::user_parent (void) const +{ + return muser->parent (); +} + +// -------------------- jit_value -------------------- +jit_value::~jit_value (void) +{} + +jit_block * +jit_value::first_use_block (void) +{ + jit_use *use = first_use (); + while (use) + { + if (! isa<jit_error_check> (use->user ())) + return use->user_parent (); + + use = use->next (); + } + + return 0; +} + +void +jit_value::replace_with (jit_value *value) +{ + while (first_use ()) + { + jit_instruction *user = first_use ()->user (); + size_t idx = first_use ()->index (); + user->stash_argument (idx, value); + } +} + +#define JIT_METH(clname) \ + void \ + jit_ ## clname::accept (jit_ir_walker& walker) \ + { \ + walker.visit (*this); \ + } + +JIT_VISIT_IR_NOTEMPLATE +#undef JIT_METH + +std::ostream& +operator<< (std::ostream& os, const jit_value& value) +{ + return value.short_print (os); +} + +std::ostream& +jit_print (std::ostream& os, jit_value *avalue) +{ + if (avalue) + return avalue->print (os); + return os << "NULL"; +} + +// -------------------- jit_instruction -------------------- +void +jit_instruction::remove (void) +{ + if (mparent) + mparent->remove (mlocation); + resize_arguments (0); +} + +llvm::BasicBlock * +jit_instruction::parent_llvm (void) const +{ + return mparent->to_llvm (); +} + +std::ostream& +jit_instruction::short_print (std::ostream& os) const +{ + if (type ()) + jit_print (os, type ()) << ": "; + return os << "#" << mid; +} + +void +jit_instruction::do_construct_ssa (size_t start, size_t end) +{ + for (size_t i = start; i < end; ++i) + { + jit_value *arg = argument (i); + jit_variable *var = dynamic_cast<jit_variable *> (arg); + if (var && var->has_top ()) + stash_argument (i, var->top ()); + } +} + +// -------------------- jit_block -------------------- +void +jit_block::replace_with (jit_value *value) +{ + assert (isa<jit_block> (value)); + jit_block *block = static_cast<jit_block *> (value); + + jit_value::replace_with (block); + + while (ILIST_T::first_use ()) + { + jit_phi_incomming *incomming = ILIST_T::first_use (); + incomming->stash_value (block); + } +} + +void +jit_block::replace_in_phi (jit_block *ablock, jit_block *with) +{ + jit_phi_incomming *node = ILIST_T::first_use (); + while (node) + { + jit_phi_incomming *prev = node; + node = node->next (); + + if (prev->user_parent () == ablock) + prev->stash_value (with); + } +} + +jit_block * +jit_block::maybe_merge () +{ + if (successor_count () == 1 && successor (0) != this + && (successor (0)->use_count () == 1 || instructions.size () == 1)) + { + jit_block *to_merge = successor (0); + merge (*to_merge); + return to_merge; + } + + return 0; +} + +void +jit_block::merge (jit_block& block) +{ + // the merge block will contain a new terminator + jit_terminator *old_term = terminator (); + if (old_term) + old_term->remove (); + + bool was_empty = end () == begin (); + iterator merge_begin = end (); + if (! was_empty) + --merge_begin; + + instructions.splice (end (), block.instructions); + if (was_empty) + merge_begin = begin (); + else + ++merge_begin; + + // now merge_begin points to the start of the new instructions, we must + // update their parent information + for (iterator iter = merge_begin; iter != end (); ++iter) + { + jit_instruction *instr = *iter; + instr->stash_parent (this, iter); + } + + block.replace_with (this); +} + +jit_instruction * +jit_block::prepend (jit_instruction *instr) +{ + instructions.push_front (instr); + instr->stash_parent (this, instructions.begin ()); + return instr; +} + +jit_instruction * +jit_block::prepend_after_phi (jit_instruction *instr) +{ + // FIXME: Make this O(1) + for (iterator iter = begin (); iter != end (); ++iter) + { + jit_instruction *temp = *iter; + if (! isa<jit_phi> (temp)) + { + insert_before (iter, instr); + return instr; + } + } + + return append (instr); +} + +void +jit_block::internal_append (jit_instruction *instr) +{ + instructions.push_back (instr); + instr->stash_parent (this, --instructions.end ()); +} + +jit_instruction * +jit_block::insert_before (iterator loc, jit_instruction *instr) +{ + iterator iloc = instructions.insert (loc, instr); + instr->stash_parent (this, iloc); + return instr; +} + +jit_instruction * +jit_block::insert_after (iterator loc, jit_instruction *instr) +{ + ++loc; + iterator iloc = instructions.insert (loc, instr); + instr->stash_parent (this, iloc); + return instr; +} + +jit_terminator * +jit_block::terminator (void) const +{ + assert (this); + if (instructions.empty ()) + return 0; + + jit_instruction *last = instructions.back (); + return dynamic_cast<jit_terminator *> (last); +} + +bool +jit_block::branch_alive (jit_block *asucc) const +{ + return terminator ()->alive (asucc); +} + +jit_block * +jit_block::successor (size_t i) const +{ + jit_terminator *term = terminator (); + return term->successor (i); +} + +size_t +jit_block::successor_count (void) const +{ + jit_terminator *term = terminator (); + return term ? term->successor_count () : 0; +} + +llvm::BasicBlock * +jit_block::to_llvm (void) const +{ + return llvm::cast<llvm::BasicBlock> (llvm_value); +} + +std::ostream& +jit_block::print_dom (std::ostream& os) const +{ + short_print (os); + os << ":\n"; + os << " mid: " << mid << std::endl; + os << " predecessors: "; + for (jit_use *use = first_use (); use; use = use->next ()) + os << *use->user_parent () << " "; + os << std::endl; + + os << " successors: "; + for (size_t i = 0; i < successor_count (); ++i) + os << *successor (i) << " "; + os << std::endl; + + os << " idom: "; + if (idom) + os << *idom; + else + os << "NULL"; + os << std::endl; + os << " df: "; + for (df_iterator iter = df_begin (); iter != df_end (); ++iter) + os << **iter << " "; + os << std::endl; + + os << " dom_succ: "; + for (size_t i = 0; i < dom_succ.size (); ++i) + os << *dom_succ[i] << " "; + + return os << std::endl; +} + +void +jit_block::compute_df (size_t avisit_count) +{ + if (visited (avisit_count)) + return; + + if (use_count () >= 2) + { + for (jit_use *use = first_use (); use; use = use->next ()) + { + jit_block *runner = use->user_parent (); + while (runner != idom) + { + runner->mdf.insert (this); + runner = runner->idom; + } + } + } + + for (size_t i = 0; i < successor_count (); ++i) + successor (i)->compute_df (avisit_count); +} + +bool +jit_block::update_idom (size_t avisit_count) +{ + if (visited (avisit_count) || ! use_count ()) + return false; + + bool changed = false; + for (jit_use *use = first_use (); use; use = use->next ()) + { + jit_block *pred = use->user_parent (); + changed = pred->update_idom (avisit_count) || changed; + } + + jit_use *use = first_use (); + jit_block *new_idom = use->user_parent (); + use = use->next (); + + for (; use; use = use->next ()) + { + jit_block *pred = use->user_parent (); + jit_block *pidom = pred->idom; + if (pidom) + new_idom = idom_intersect (pidom, new_idom); + } + + if (idom != new_idom) + { + idom = new_idom; + return true; + } + + return changed; +} + +void +jit_block::pop_all (void) +{ + for (iterator iter = begin (); iter != end (); ++iter) + { + jit_instruction *instr = *iter; + instr->pop_variable (); + } +} + +jit_block * +jit_block::maybe_split (jit_convert& convert, jit_block *asuccessor) +{ + if (successor_count () > 1) + { + jit_terminator *term = terminator (); + size_t idx = term->successor_index (asuccessor); + jit_block *split = convert.create<jit_block> ("phi_split", mvisit_count); + + // try to place splits where they make sense + if (id () < asuccessor->id ()) + convert.insert_before (asuccessor, split); + else + convert.insert_after (this, split); + + term->stash_argument (idx, split); + jit_branch *br = split->append (convert.create<jit_branch> (asuccessor)); + replace_in_phi (asuccessor, split); + + if (alive ()) + { + split->mark_alive (); + br->infer (); + } + + return split; + } + + return this; +} + +void +jit_block::create_dom_tree (size_t avisit_count) +{ + if (visited (avisit_count)) + return; + + if (idom != this) + idom->dom_succ.push_back (this); + + for (size_t i = 0; i < successor_count (); ++i) + successor (i)->create_dom_tree (avisit_count); +} + +jit_block * +jit_block::idom_intersect (jit_block *i, jit_block *j) +{ + while (i && j && i != j) + { + while (i && i->id () > j->id ()) + i = i->idom; + + while (i && j && j->id () > i->id ()) + j = j->idom; + } + + return i ? i : j; +} + +// -------------------- jit_phi_incomming -------------------- + +jit_block * +jit_phi_incomming::user_parent (void) const +{ return muser->parent (); } + +// -------------------- jit_phi -------------------- +bool +jit_phi::prune (void) +{ + jit_block *p = parent (); + size_t new_idx = 0; + jit_value *unique = argument (1); + + for (size_t i = 0; i < argument_count (); ++i) + { + jit_block *inc = incomming (i); + if (inc->branch_alive (p)) + { + if (unique != argument (i)) + unique = 0; + + if (new_idx != i) + { + stash_argument (new_idx, argument (i)); + mincomming[new_idx].stash_value (inc); + } + + ++new_idx; + } + } + + if (new_idx != argument_count ()) + { + resize_arguments (new_idx); + mincomming.resize (new_idx); + } + + assert (argument_count () > 0); + if (unique) + { + replace_with (unique); + return true; + } + + return false; +} + +bool +jit_phi::infer (void) +{ + jit_block *p = parent (); + if (! p->alive ()) + return false; + + jit_type *infered = 0; + for (size_t i = 0; i < argument_count (); ++i) + { + jit_block *inc = incomming (i); + if (inc->branch_alive (p)) + infered = jit_typeinfo::join (infered, argument_type (i)); + } + + if (infered != type ()) + { + stash_type (infered); + return true; + } + + return false; +} + +llvm::PHINode * +jit_phi::to_llvm (void) const +{ + return llvm::cast<llvm::PHINode> (jit_value::to_llvm ()); +} + +// -------------------- jit_terminator -------------------- +size_t +jit_terminator::successor_index (const jit_block *asuccessor) const +{ + size_t scount = successor_count (); + for (size_t i = 0; i < scount; ++i) + if (successor (i) == asuccessor) + return i; + + panic_impossible (); +} + +bool +jit_terminator::infer (void) +{ + if (! parent ()->alive ()) + return false; + + bool changed = false; + for (size_t i = 0; i < malive.size (); ++i) + if (! malive[i] && check_alive (i)) + { + changed = true; + malive[i] = true; + successor (i)->mark_alive (); + } + + return changed; +} + +llvm::TerminatorInst * +jit_terminator::to_llvm (void) const +{ + return llvm::cast<llvm::TerminatorInst> (jit_value::to_llvm ()); +} + +// -------------------- jit_call -------------------- +bool +jit_call::infer (void) +{ + // FIXME: explain algorithm + for (size_t i = 0; i < argument_count (); ++i) + { + already_infered[i] = argument_type (i); + if (! already_infered[i]) + return false; + } + + jit_type *infered = moperation.result (already_infered); + if (! infered && use_count ()) + { + std::stringstream ss; + ss << "Missing overload in type inference for "; + print (ss, 0); + throw jit_fail_exception (ss.str ()); + } + + if (infered != type ()) + { + stash_type (infered); + return true; + } + + return false; +} + +#endif
new file mode 100644 --- /dev/null +++ b/src/jit-ir.h @@ -0,0 +1,1247 @@ +/* + +Copyright (C) 2012 Max Brister <max@2bass.com> + +This file is part of Octave. + +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 3 of the License, or (at your +option) any later version. + +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, see +<http://www.gnu.org/licenses/>. + +*/ + +#if !defined (octave_jit_ir_h) +#define octave_jit_ir_h 1 + +#ifdef HAVE_LLVM + +#include <list> +#include <stack> +#include <set> + +#include "jit-typeinfo.h" + +// The low level octave jit ir +// this ir is close to llvm, but contains information for doing type inference. +// We convert the octave parse tree to this IR directly. + +#define JIT_VISIT_IR_NOTEMPLATE \ + JIT_METH(block); \ + JIT_METH(branch); \ + JIT_METH(cond_branch); \ + JIT_METH(call); \ + JIT_METH(extract_argument); \ + JIT_METH(store_argument); \ + JIT_METH(phi); \ + JIT_METH(variable); \ + JIT_METH(error_check); \ + JIT_METH(assign) \ + JIT_METH(argument) + +#define JIT_VISIT_IR_CONST \ + JIT_METH(const_bool); \ + JIT_METH(const_scalar); \ + JIT_METH(const_complex); \ + JIT_METH(const_index); \ + JIT_METH(const_string); \ + JIT_METH(const_range) + +#define JIT_VISIT_IR_CLASSES \ + JIT_VISIT_IR_NOTEMPLATE \ + JIT_VISIT_IR_CONST + +// forward declare all ir classes +#define JIT_METH(cname) \ + class jit_ ## cname; + +JIT_VISIT_IR_NOTEMPLATE + +#undef JIT_METH + +class jit_convert; + +// ABCs which aren't included in JIT_VISIT_IR_ALL +class jit_instruction; +class jit_terminator; + +template <typename T, jit_type *(*EXTRACT_T)(void), typename PASS_T = T, + bool QUOTE=false> +class jit_const; + +typedef jit_const<bool, jit_typeinfo::get_bool> jit_const_bool; +typedef jit_const<double, jit_typeinfo::get_scalar> jit_const_scalar; +typedef jit_const<Complex, jit_typeinfo::get_complex> jit_const_complex; +typedef jit_const<octave_idx_type, jit_typeinfo::get_index> jit_const_index; + +typedef jit_const<std::string, jit_typeinfo::get_string, const std::string&, + true> jit_const_string; +typedef jit_const<jit_range, jit_typeinfo::get_range, const jit_range&> +jit_const_range; + +class jit_ir_walker; +class jit_use; + +class +jit_value : public jit_internal_list<jit_value, jit_use> +{ +public: + jit_value (void) : llvm_value (0), ty (0), mlast_use (0), + min_worklist (false) {} + + virtual ~jit_value (void); + + bool in_worklist (void) const + { + return min_worklist; + } + + void stash_in_worklist (bool ain_worklist) + { + min_worklist = ain_worklist; + } + + // The block of the first use which is not a jit_error_check + // So this is not necessarily first_use ()->parent (). + jit_block *first_use_block (void); + + // replace all uses with + virtual void replace_with (jit_value *value); + + jit_type *type (void) const { return ty; } + + llvm::Type *type_llvm (void) const + { + return ty ? ty->to_llvm () : 0; + } + + const std::string& type_name (void) const + { + return ty->name (); + } + + void stash_type (jit_type *new_ty) { ty = new_ty; } + + std::string print_string (void) + { + std::stringstream ss; + print (ss); + return ss.str (); + } + + jit_instruction *last_use (void) const { return mlast_use; } + + void stash_last_use (jit_instruction *alast_use) + { + mlast_use = alast_use; + } + + virtual bool needs_release (void) const { return false; } + + virtual std::ostream& print (std::ostream& os, size_t indent = 0) const = 0; + + virtual std::ostream& short_print (std::ostream& os) const + { return print (os); } + + virtual void accept (jit_ir_walker& walker) = 0; + + bool has_llvm (void) const + { + return llvm_value; + } + + llvm::Value *to_llvm (void) const + { + assert (llvm_value); + return llvm_value; + } + + void stash_llvm (llvm::Value *compiled) + { + llvm_value = compiled; + } + +protected: + std::ostream& print_indent (std::ostream& os, size_t indent = 0) const + { + for (size_t i = 0; i < indent * 8; ++i) + os << " "; + return os; + } + + llvm::Value *llvm_value; +private: + jit_type *ty; + jit_instruction *mlast_use; + bool min_worklist; +}; + +std::ostream& operator<< (std::ostream& os, const jit_value& value); +std::ostream& jit_print (std::ostream& os, jit_value *avalue); + +class +jit_use : public jit_internal_node<jit_value, jit_use> +{ +public: + jit_use (void) : muser (0), mindex (0) {} + + // we should really have a move operator, but not until c++11 :( + jit_use (const jit_use& use) : muser (0), mindex (0) + { + *this = use; + } + + jit_use& operator= (const jit_use& use) + { + stash_value (use.value (), use.user (), use.index ()); + return *this; + } + + size_t index (void) const { return mindex; } + + jit_instruction *user (void) const { return muser; } + + jit_block *user_parent (void) const; + + std::list<jit_block *> user_parent_location (void) const; + + void stash_value (jit_value *avalue, jit_instruction *auser = 0, + size_t aindex = -1) + { + jit_internal_node::stash_value (avalue); + mindex = aindex; + muser = auser; + } +private: + jit_instruction *muser; + size_t mindex; +}; + +class +jit_instruction : public jit_value +{ +public: + // FIXME: this code could be so much pretier with varadic templates... + jit_instruction (void) : mid (next_id ()), mparent (0) + {} + + jit_instruction (size_t nargs) : mid (next_id ()), mparent (0) + { + already_infered.reserve (nargs); + marguments.reserve (nargs); + } + +#define STASH_ARG(i) stash_argument (i, arg ## i); +#define JIT_INSTRUCTION_CTOR(N) \ + jit_instruction (OCT_MAKE_DECL_LIST (jit_value *, arg, N)) \ + : already_infered (N), marguments (N), mid (next_id ()), mparent (0) \ + { \ + OCT_ITERATE_MACRO (STASH_ARG, N); \ + } + + JIT_INSTRUCTION_CTOR(1) + JIT_INSTRUCTION_CTOR(2) + JIT_INSTRUCTION_CTOR(3) + JIT_INSTRUCTION_CTOR(4) + +#undef STASH_ARG +#undef JIT_INSTRUCTION_CTOR + + static void reset_ids (void) + { + next_id (true); + } + + jit_value *argument (size_t i) const + { + return marguments[i].value (); + } + + llvm::Value *argument_llvm (size_t i) const + { + assert (argument (i)); + return argument (i)->to_llvm (); + } + + jit_type *argument_type (size_t i) const + { + return argument (i)->type (); + } + + llvm::Type *argument_type_llvm (size_t i) const + { + assert (argument (i)); + return argument_type (i)->to_llvm (); + } + + std::ostream& print_argument (std::ostream& os, size_t i) const + { + if (argument (i)) + return argument (i)->short_print (os); + else + return os << "NULL"; + } + + void stash_argument (size_t i, jit_value *arg) + { + marguments[i].stash_value (arg, this, i); + } + + void push_argument (jit_value *arg) + { + marguments.push_back (jit_use ()); + stash_argument (marguments.size () - 1, arg); + already_infered.push_back (0); + } + + size_t argument_count (void) const + { + return marguments.size (); + } + + void resize_arguments (size_t acount, jit_value *adefault = 0) + { + size_t old = marguments.size (); + marguments.resize (acount); + already_infered.resize (acount); + + if (adefault) + for (size_t i = old; i < acount; ++i) + stash_argument (i, adefault); + } + + const std::vector<jit_use>& arguments (void) const { return marguments; } + + // argument types which have been infered already + const std::vector<jit_type *>& argument_types (void) const + { return already_infered; } + + virtual void push_variable (void) {} + + virtual void pop_variable (void) {} + + virtual void construct_ssa (void) + { + do_construct_ssa (0, argument_count ()); + } + + virtual bool infer (void) { return false; } + + void remove (void); + + virtual std::ostream& short_print (std::ostream& os) const; + + jit_block *parent (void) const { return mparent; } + + std::list<jit_instruction *>::iterator location (void) const + { + return mlocation; + } + + llvm::BasicBlock *parent_llvm (void) const; + + void stash_parent (jit_block *aparent, + std::list<jit_instruction *>::iterator alocation) + { + mparent = aparent; + mlocation = alocation; + } + + size_t id (void) const { return mid; } +protected: + + // Do SSA replacement on arguments in [start, end) + void do_construct_ssa (size_t start, size_t end); + + std::vector<jit_type *> already_infered; +private: + static size_t next_id (bool reset = false) + { + static size_t ret = 0; + if (reset) + return ret = 0; + + return ret++; + } + + std::vector<jit_use> marguments; + + size_t mid; + jit_block *mparent; + std::list<jit_instruction *>::iterator mlocation; +}; + +// defnie accept methods for subclasses +#define JIT_VALUE_ACCEPT \ + virtual void accept (jit_ir_walker& walker); + +// for use as a dummy argument during conversion to LLVM +class +jit_argument : public jit_value +{ +public: + jit_argument (jit_type *atype, llvm::Value *avalue) + { + stash_type (atype); + stash_llvm (avalue); + } + + virtual std::ostream& print (std::ostream& os, size_t indent = 0) const + { + print_indent (os, indent); + return jit_print (os, type ()) << ": DUMMY"; + } + + JIT_VALUE_ACCEPT; +}; + +template <typename T, jit_type *(*EXTRACT_T)(void), typename PASS_T, + bool QUOTE> +class +jit_const : public jit_value +{ +public: + typedef PASS_T pass_t; + + jit_const (PASS_T avalue) : mvalue (avalue) + { + stash_type (EXTRACT_T ()); + } + + PASS_T value (void) const { return mvalue; } + + virtual std::ostream& print (std::ostream& os, size_t indent = 0) const + { + print_indent (os, indent); + jit_print (os, type ()) << ": "; + if (QUOTE) + os << "\""; + os << mvalue; + if (QUOTE) + os << "\""; + return os; + } + + JIT_VALUE_ACCEPT; +private: + T mvalue; +}; + +class jit_phi_incomming; + +class +jit_block : public jit_value, public jit_internal_list<jit_block, + jit_phi_incomming> +{ + typedef jit_internal_list<jit_block, jit_phi_incomming> ILIST_T; +public: + typedef std::list<jit_instruction *> instruction_list; + typedef instruction_list::iterator iterator; + typedef instruction_list::const_iterator const_iterator; + + typedef std::set<jit_block *> df_set; + typedef df_set::const_iterator df_iterator; + + static const size_t NO_ID = static_cast<size_t> (-1); + + jit_block (const std::string& aname, size_t avisit_count = 0) + : mvisit_count (avisit_count), mid (NO_ID), idom (0), mname (aname), + malive (false) + {} + + virtual void replace_with (jit_value *value); + + void replace_in_phi (jit_block *ablock, jit_block *with); + + // we have a new internal list, but we want to stay compatable with jit_value + jit_use *first_use (void) const { return jit_value::first_use (); } + + size_t use_count (void) const { return jit_value::use_count (); } + + // if a block is alive, then it might be visited during execution + bool alive (void) const { return malive; } + + void mark_alive (void) { malive = true; } + + // If we can merge with a successor, do so and return the now empty block + jit_block *maybe_merge (); + + // merge another block into this block, leaving the merge block empty + void merge (jit_block& merge); + + const std::string& name (void) const { return mname; } + + jit_instruction *prepend (jit_instruction *instr); + + jit_instruction *prepend_after_phi (jit_instruction *instr); + + template <typename T> + T *append (T *instr) + { + internal_append (instr); + return instr; + } + + jit_instruction *insert_before (iterator loc, jit_instruction *instr); + + jit_instruction *insert_before (jit_instruction *loc, jit_instruction *instr) + { + return insert_before (loc->location (), instr); + } + + jit_instruction *insert_after (iterator loc, jit_instruction *instr); + + jit_instruction *insert_after (jit_instruction *loc, jit_instruction *instr) + { + return insert_after (loc->location (), instr); + } + + iterator remove (iterator iter) + { + jit_instruction *instr = *iter; + iter = instructions.erase (iter); + instr->stash_parent (0, instructions.end ()); + return iter; + } + + jit_terminator *terminator (void) const; + + // is the jump from pred alive? + bool branch_alive (jit_block *asucc) const; + + jit_block *successor (size_t i) const; + + size_t successor_count (void) const; + + iterator begin (void) { return instructions.begin (); } + + const_iterator begin (void) const { return instructions.begin (); } + + iterator end (void) { return instructions.end (); } + + const_iterator end (void) const { return instructions.end (); } + + iterator phi_begin (void); + + iterator phi_end (void); + + iterator nonphi_begin (void); + + // must label before id is valid + size_t id (void) const { return mid; } + + // dominance frontier + const df_set& df (void) const { return mdf; } + + df_iterator df_begin (void) const { return mdf.begin (); } + + df_iterator df_end (void) const { return mdf.end (); } + + // label with a RPO walk + void label (void) + { + size_t number = 0; + label (mvisit_count, number); + } + + void label (size_t avisit_count, size_t& number) + { + if (visited (avisit_count)) + return; + + for (jit_use *use = first_use (); use; use = use->next ()) + { + jit_block *pred = use->user_parent (); + pred->label (avisit_count, number); + } + + mid = number++; + } + + // See for idom computation algorithm + // Cooper, Keith D.; Harvey, Timothy J; and Kennedy, Ken (2001). + // "A Simple, Fast Dominance Algorithm" + void compute_idom (jit_block *entry_block) + { + bool changed; + entry_block->idom = entry_block; + do + changed = update_idom (mvisit_count); + while (changed); + } + + // compute dominance frontier + void compute_df (void) + { + compute_df (mvisit_count); + } + + void create_dom_tree (void) + { + create_dom_tree (mvisit_count); + } + + jit_block *dom_successor (size_t idx) const + { + return dom_succ[idx]; + } + + size_t dom_successor_count (void) const + { + return dom_succ.size (); + } + + // call pop_varaible on all instructions + void pop_all (void); + + virtual std::ostream& print (std::ostream& os, size_t indent = 0) const + { + print_indent (os, indent); + short_print (os) << ": %pred = "; + for (jit_use *use = first_use (); use; use = use->next ()) + { + jit_block *pred = use->user_parent (); + os << *pred; + if (use->next ()) + os << ", "; + } + os << std::endl; + + for (const_iterator iter = begin (); iter != end (); ++iter) + { + jit_instruction *instr = *iter; + instr->print (os, indent + 1) << std::endl; + } + return os; + } + + // ... + jit_block *maybe_split (jit_convert& convert, jit_block *asuccessor); + + jit_block *maybe_split (jit_convert& convert, jit_block& asuccessor) + { + return maybe_split (convert, &asuccessor); + } + + // print dominator infomration + std::ostream& print_dom (std::ostream& os) const; + + virtual std::ostream& short_print (std::ostream& os) const + { + os << mname; + if (mid != NO_ID) + os << mid; + return os; + } + + llvm::BasicBlock *to_llvm (void) const; + + std::list<jit_block *>::iterator location (void) const + { return mlocation; } + + void stash_location (std::list<jit_block *>::iterator alocation) + { mlocation = alocation; } + + // used to prevent visiting the same node twice in the graph + size_t visit_count (void) const { return mvisit_count; } + + // check if this node has been visited yet at the given visit count. If we + // have not been visited yet, mark us as visited. + bool visited (size_t avisit_count) + { + if (mvisit_count <= avisit_count) + { + mvisit_count = avisit_count + 1; + return false; + } + + return true; + } + + JIT_VALUE_ACCEPT; +private: + void internal_append (jit_instruction *instr); + + void compute_df (size_t avisit_count); + + bool update_idom (size_t avisit_count); + + void create_dom_tree (size_t avisit_count); + + static jit_block *idom_intersect (jit_block *i, jit_block *j); + + size_t mvisit_count; + size_t mid; + jit_block *idom; + df_set mdf; + std::vector<jit_block *> dom_succ; + std::string mname; + instruction_list instructions; + bool malive; + std::list<jit_block *>::iterator mlocation; +}; + +// keeps track of phi functions that use a block on incomming edges +class +jit_phi_incomming : public jit_internal_node<jit_block, jit_phi_incomming> +{ +public: + jit_phi_incomming (void) : muser (0) {} + + jit_phi_incomming (jit_phi *auser) : muser (auser) {} + + jit_phi_incomming (const jit_phi_incomming& use) : jit_internal_node () + { + *this = use; + } + + jit_phi_incomming& operator= (const jit_phi_incomming& use) + { + stash_value (use.value ()); + muser = use.muser; + return *this; + } + + jit_phi *user (void) const { return muser; } + + jit_block *user_parent (void) const; +private: + jit_phi *muser; +}; + +// A non-ssa variable +class +jit_variable : public jit_value +{ +public: + jit_variable (const std::string& aname) : mname (aname), mlast_use (0) {} + + const std::string &name (void) const { return mname; } + + // manipulate the value_stack, for use during SSA construction. The top of the + // value stack represents the current value for this variable + bool has_top (void) const + { + return ! value_stack.empty (); + } + + jit_value *top (void) const + { + return value_stack.top (); + } + + void push (jit_instruction *v) + { + value_stack.push (v); + mlast_use = v; + } + + void pop (void) + { + value_stack.pop (); + } + + jit_instruction *last_use (void) const + { + return mlast_use; + } + + void stash_last_use (jit_instruction *instr) + { + mlast_use = instr; + } + + // blocks in which we are used + void use_blocks (jit_block::df_set& result) + { + jit_use *use = first_use (); + while (use) + { + result.insert (use->user_parent ()); + use = use->next (); + } + } + + virtual std::ostream& print (std::ostream& os, size_t indent = 0) const + { + return print_indent (os, indent) << mname; + } + + JIT_VALUE_ACCEPT; +private: + std::string mname; + std::stack<jit_value *> value_stack; + jit_instruction *mlast_use; +}; + +class +jit_assign_base : public jit_instruction +{ +public: + jit_assign_base (jit_variable *adest) : jit_instruction (), mdest (adest) {} + + jit_assign_base (jit_variable *adest, size_t npred) : jit_instruction (npred), + mdest (adest) {} + + jit_assign_base (jit_variable *adest, jit_value *arg0, jit_value *arg1) + : jit_instruction (arg0, arg1), mdest (adest) {} + + jit_variable *dest (void) const { return mdest; } + + virtual void push_variable (void) + { + mdest->push (this); + } + + virtual void pop_variable (void) + { + mdest->pop (); + } + + virtual std::ostream& short_print (std::ostream& os) const + { + if (type ()) + jit_print (os, type ()) << ": "; + + dest ()->short_print (os); + return os << "#" << id (); + } +private: + jit_variable *mdest; +}; + +class +jit_assign : public jit_assign_base +{ +public: + jit_assign (jit_variable *adest, jit_value *asrc) + : jit_assign_base (adest, adest, asrc), martificial (false) {} + + jit_value *overwrite (void) const + { + return argument (0); + } + + jit_value *src (void) const + { + return argument (1); + } + + // variables don't get modified in an SSA, but COW requires we modify + // variables. An artificial assign is for when a variable gets modified. We + // need an assign in the SSA, but the reference counts shouldn't be updated. + bool artificial (void) const { return martificial; } + + void mark_artificial (void) { martificial = true; } + + virtual bool infer (void) + { + jit_type *stype = src ()->type (); + if (stype != type()) + { + stash_type (stype); + return true; + } + + return false; + } + + virtual std::ostream& print (std::ostream& os, size_t indent = 0) const + { + print_indent (os, indent) << *this << " = " << *src (); + + if (artificial ()) + os << " [artificial]"; + + return os; + } + + JIT_VALUE_ACCEPT; +private: + bool martificial; +}; + +class +jit_phi : public jit_assign_base +{ +public: + jit_phi (jit_variable *adest, size_t npred) + : jit_assign_base (adest, npred) + { + mincomming.reserve (npred); + } + + // removes arguments form dead incomming jumps + bool prune (void); + + void add_incomming (jit_block *from, jit_value *value) + { + push_argument (value); + mincomming.push_back (jit_phi_incomming (this)); + mincomming[mincomming.size () - 1].stash_value (from); + } + + jit_block *incomming (size_t i) const + { + return mincomming[i].value (); + } + + llvm::BasicBlock *incomming_llvm (size_t i) const + { + return incomming (i)->to_llvm (); + } + + virtual void construct_ssa (void) {} + + virtual bool infer (void); + + virtual std::ostream& print (std::ostream& os, size_t indent = 0) const + { + std::stringstream ss; + print_indent (ss, indent); + short_print (ss) << " phi "; + std::string ss_str = ss.str (); + std::string indent_str (ss_str.size (), ' '); + os << ss_str; + + for (size_t i = 0; i < argument_count (); ++i) + { + if (i > 0) + os << indent_str; + os << "| "; + + os << *incomming (i) << " -> "; + os << *argument (i); + + if (i + 1 < argument_count ()) + os << std::endl; + } + + return os; + } + + llvm::PHINode *to_llvm (void) const; + + JIT_VALUE_ACCEPT; +private: + std::vector<jit_phi_incomming> mincomming; +}; + +class +jit_terminator : public jit_instruction +{ +public: +#define JIT_TERMINATOR_CONST(N) \ + jit_terminator (size_t asuccessor_count, \ + OCT_MAKE_DECL_LIST (jit_value *, arg, N)) \ + : jit_instruction (OCT_MAKE_ARG_LIST (arg, N)), \ + malive (asuccessor_count, false) {} + + JIT_TERMINATOR_CONST (1) + JIT_TERMINATOR_CONST (2) + JIT_TERMINATOR_CONST (3) + +#undef JIT_TERMINATOR_CONST + + jit_block *successor (size_t idx = 0) const + { + return static_cast<jit_block *> (argument (idx)); + } + + llvm::BasicBlock *successor_llvm (size_t idx = 0) const + { + return successor (idx)->to_llvm (); + } + + size_t successor_index (const jit_block *asuccessor) const; + + std::ostream& print_successor (std::ostream& os, size_t idx = 0) const + { + if (alive (idx)) + os << "[live] "; + else + os << "[dead] "; + + return successor (idx)->short_print (os); + } + + // Check if the jump to successor is live + bool alive (const jit_block *asuccessor) const + { + return alive (successor_index (asuccessor)); + } + + bool alive (size_t idx) const { return malive[idx]; } + + bool alive (int idx) const { return malive[idx]; } + + size_t successor_count (void) const { return malive.size (); } + + virtual bool infer (void); + + llvm::TerminatorInst *to_llvm (void) const; +protected: + virtual bool check_alive (size_t) const { return true; } +private: + std::vector<bool> malive; +}; + +class +jit_branch : public jit_terminator +{ +public: + jit_branch (jit_block *succ) : jit_terminator (1, succ) {} + + virtual size_t successor_count (void) const { return 1; } + + virtual std::ostream& print (std::ostream& os, size_t indent = 0) const + { + print_indent (os, indent) << "branch: "; + return print_successor (os); + } + + JIT_VALUE_ACCEPT; +}; + +class +jit_cond_branch : public jit_terminator +{ +public: + jit_cond_branch (jit_value *c, jit_block *ctrue, jit_block *cfalse) + : jit_terminator (2, ctrue, cfalse, c) {} + + jit_value *cond (void) const { return argument (2); } + + std::ostream& print_cond (std::ostream& os) const + { + return cond ()->short_print (os); + } + + llvm::Value *cond_llvm (void) const + { + return cond ()->to_llvm (); + } + + virtual size_t successor_count (void) const { return 2; } + + virtual std::ostream& print (std::ostream& os, size_t indent = 0) const + { + print_indent (os, indent) << "cond_branch: "; + print_cond (os) << ", "; + print_successor (os, 0) << ", "; + return print_successor (os, 1); + } + + JIT_VALUE_ACCEPT; +}; + +class +jit_call : public jit_instruction +{ +public: +#define JIT_CALL_CONST(N) \ + jit_call (const jit_operation& aoperation, \ + OCT_MAKE_DECL_LIST (jit_value *, arg, N)) \ + : jit_instruction (OCT_MAKE_ARG_LIST (arg, N)), moperation (aoperation) {} \ + \ + jit_call (const jit_operation& (*aoperation) (void), \ + OCT_MAKE_DECL_LIST (jit_value *, arg, N)) \ + : jit_instruction (OCT_MAKE_ARG_LIST (arg, N)), moperation (aoperation ()) \ + {} + + JIT_CALL_CONST (1) + JIT_CALL_CONST (2) + JIT_CALL_CONST (3) + JIT_CALL_CONST (4) + +#undef JIT_CALL_CONST + + + const jit_operation& operation (void) const { return moperation; } + + bool can_error (void) const + { + return overload ().can_error (); + } + + const jit_function& overload (void) const + { + return moperation.overload (argument_types ()); + } + + virtual bool needs_release (void) const + { + return type () && jit_typeinfo::get_release (type ()).valid (); + } + + virtual std::ostream& print (std::ostream& os, size_t indent = 0) const + { + print_indent (os, indent); + + if (use_count ()) + short_print (os) << " = "; + os << "call " << moperation.name () << " ("; + + for (size_t i = 0; i < argument_count (); ++i) + { + print_argument (os, i); + if (i + 1 < argument_count ()) + os << ", "; + } + return os << ")"; + } + + virtual bool infer (void); + + JIT_VALUE_ACCEPT; +private: + const jit_operation& moperation; +}; + +// FIXME: This is just ugly... +// checks error_state, if error_state is false then goto the normal branche, +// otherwise goto the error branch +class +jit_error_check : public jit_terminator +{ +public: + jit_error_check (jit_call *acheck_for, jit_block *normal, jit_block *error) + : jit_terminator (2, error, normal, acheck_for) {} + + jit_call *check_for (void) const + { + return static_cast<jit_call *> (argument (2)); + } + + virtual std::ostream& print (std::ostream& os, size_t indent = 0) const + { + print_indent (os, indent) << "error_check " << *check_for () << ", "; + print_successor (os, 1) << ", "; + return print_successor (os, 0); + } + + JIT_VALUE_ACCEPT; +protected: + virtual bool check_alive (size_t idx) const + { + return idx == 1 ? true : check_for ()->can_error (); + } +}; + +class +jit_extract_argument : public jit_assign_base +{ +public: + jit_extract_argument (jit_type *atype, jit_variable *adest) + : jit_assign_base (adest) + { + stash_type (atype); + } + + const std::string& name (void) const + { + return dest ()->name (); + } + + const jit_function& overload (void) const + { + return jit_typeinfo::cast (type (), jit_typeinfo::get_any ()); + } + + virtual std::ostream& print (std::ostream& os, size_t indent = 0) const + { + print_indent (os, indent); + + return short_print (os) << " = extract " << name (); + } + + JIT_VALUE_ACCEPT; +}; + +class +jit_store_argument : public jit_instruction +{ +public: + jit_store_argument (jit_variable *var) + : jit_instruction (var), dest (var) + {} + + const std::string& name (void) const + { + return dest->name (); + } + + const jit_function& overload (void) const + { + return jit_typeinfo::cast (jit_typeinfo::get_any (), result_type ()); + } + + jit_value *result (void) const + { + return argument (0); + } + + jit_type *result_type (void) const + { + return result ()->type (); + } + + llvm::Value *result_llvm (void) const + { + return result ()->to_llvm (); + } + + virtual std::ostream& print (std::ostream& os, size_t indent = 0) const + { + jit_value *res = result (); + print_indent (os, indent) << "store "; + dest->short_print (os); + + if (! isa<jit_variable> (res)) + { + os << " = "; + res->short_print (os); + } + + return os; + } + + JIT_VALUE_ACCEPT; +private: + jit_variable *dest; +}; + +class +jit_ir_walker +{ +public: + virtual ~jit_ir_walker () {} + +#define JIT_METH(clname) \ + virtual void visit (jit_ ## clname&) = 0; + + JIT_VISIT_IR_CLASSES; + +#undef JIT_METH +}; + +template <typename T, jit_type *(*EXTRACT_T)(void), typename PASS_T, bool QUOTE> +void +jit_const<T, EXTRACT_T, PASS_T, QUOTE>::accept (jit_ir_walker& walker) +{ + walker.visit (*this); +} + +#undef JIT_VALUE_ACCEPT + +#endif +#endif
new file mode 100644 --- /dev/null +++ b/src/jit-typeinfo.cc @@ -0,0 +1,1754 @@ +/* + +Copyright (C) 2012 Max Brister <max@2bass.com> + +This file is part of Octave. + +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 3 of the License, or (at your +option) any later version. + +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, see +<http://www.gnu.org/licenses/>. + +*/ + +// defines required by llvm +#define __STDC_LIMIT_MACROS +#define __STDC_CONSTANT_MACROS + +#ifdef HAVE_CONFIG_H +#include <config.h> +#endif + +#ifdef HAVE_LLVM + +#include "jit-typeinfo.h" + +#include <llvm/Analysis/Verifier.h> +#include <llvm/GlobalVariable.h> +#include <llvm/ExecutionEngine/ExecutionEngine.h> +#include <llvm/LLVMContext.h> +#include <llvm/Function.h> +#include <llvm/Instructions.h> +#include <llvm/Intrinsics.h> +#include <llvm/Support/IRBuilder.h> +#include <llvm/Support/raw_os_ostream.h> + +#include "jit-ir.h" +#include "ov.h" +#include "ov-builtin.h" +#include "ov-complex.h" +#include "ov-scalar.h" +#include "pager.h" + +static llvm::LLVMContext& context = llvm::getGlobalContext (); + +jit_typeinfo *jit_typeinfo::instance = 0; + +std::ostream& jit_print (std::ostream& os, jit_type *atype) +{ + if (! atype) + return os << "null"; + return os << atype->name (); +} + +// function that jit code calls +extern "C" void +octave_jit_print_any (const char *name, octave_base_value *obv) +{ + obv->print_with_name (octave_stdout, name, true); +} + +extern "C" void +octave_jit_print_double (const char *name, double value) +{ + // FIXME: We should avoid allocating a new octave_scalar each time + octave_value ov (value); + ov.print_with_name (octave_stdout, name); +} + +extern "C" octave_base_value* +octave_jit_binary_any_any (octave_value::binary_op op, octave_base_value *lhs, + octave_base_value *rhs) +{ + octave_value olhs (lhs, true); + octave_value orhs (rhs, true); + octave_value result = do_binary_op (op, olhs, orhs); + octave_base_value *rep = result.internal_rep (); + rep->grab (); + return rep; +} + +extern "C" octave_idx_type +octave_jit_compute_nelem (double base, double limit, double inc) +{ + Range rng = Range (base, limit, inc); + return rng.nelem (); +} + +extern "C" void +octave_jit_release_any (octave_base_value *obv) +{ + obv->release (); +} + +extern "C" void +octave_jit_release_matrix (jit_matrix *m) +{ + delete m->array; +} + +extern "C" octave_base_value * +octave_jit_grab_any (octave_base_value *obv) +{ + obv->grab (); + return obv; +} + +extern "C" void +octave_jit_grab_matrix (jit_matrix *result, jit_matrix *m) +{ + *result = *m->array; +} + +extern "C" octave_base_value * +octave_jit_cast_any_matrix (jit_matrix *m) +{ + octave_value ret (*m->array); + octave_base_value *rep = ret.internal_rep (); + rep->grab (); + delete m->array; + + return rep; +} + +extern "C" void +octave_jit_cast_matrix_any (jit_matrix *ret, octave_base_value *obv) +{ + NDArray m = obv->array_value (); + *ret = m; + obv->release (); +} + +extern "C" double +octave_jit_cast_scalar_any (octave_base_value *obv) +{ + double ret = obv->double_value (); + obv->release (); + return ret; +} + +extern "C" octave_base_value * +octave_jit_cast_any_scalar (double value) +{ + return new octave_scalar (value); +} + +extern "C" Complex +octave_jit_cast_complex_any (octave_base_value *obv) +{ + Complex ret = obv->complex_value (); + obv->release (); + return ret; +} + +extern "C" octave_base_value * +octave_jit_cast_any_complex (Complex c) +{ + if (c.imag () == 0) + return new octave_scalar (c.real ()); + else + return new octave_complex (c); +} + +extern "C" void +octave_jit_gripe_nan_to_logical_conversion (void) +{ + try + { + gripe_nan_to_logical_conversion (); + } + catch (const octave_execution_exception&) + { + gripe_library_execution_error (); + } +} + +extern "C" void +octave_jit_ginvalid_index (void) +{ + try + { + gripe_invalid_index (); + } + catch (const octave_execution_exception&) + { + gripe_library_execution_error (); + } +} + +extern "C" void +octave_jit_gindex_range (int nd, int dim, octave_idx_type iext, + octave_idx_type ext) +{ + try + { + gripe_index_out_of_range (nd, dim, iext, ext); + } + catch (const octave_execution_exception&) + { + gripe_library_execution_error (); + } +} + +extern "C" void +octave_jit_paren_subsasgn_impl (jit_matrix *mat, octave_idx_type index, + double value) +{ + NDArray *array = mat->array; + if (array->nelem () < index) + array->resize1 (index); + + double *data = array->fortran_vec (); + data[index - 1] = value; + + mat->update (); +} + +extern "C" void +octave_jit_paren_subsasgn_matrix_range (jit_matrix *result, jit_matrix *mat, + jit_range *index, double value) +{ + NDArray *array = mat->array; + bool done = false; + + // optimize for the simple case (no resizing and no errors) + if (*array->jit_ref_count () == 1 + && index->all_elements_are_ints ()) + { + // this code is similar to idx_vector::fill, but we avoid allocating an + // idx_vector and its associated rep + octave_idx_type start = static_cast<octave_idx_type> (index->base) - 1; + octave_idx_type step = static_cast<octave_idx_type> (index->inc); + octave_idx_type nelem = index->nelem; + octave_idx_type final = start + nelem * step; + if (step < 0) + { + step = -step; + std::swap (final, start); + } + + if (start >= 0 && final < mat->slice_len) + { + done = true; + + double *data = array->jit_slice_data (); + if (step == 1) + std::fill (data + start, data + start + nelem, value); + else + { + for (octave_idx_type i = start; i < final; i += step) + data[i] = value; + } + } + } + + if (! done) + { + idx_vector idx (*index); + NDArray avalue (dim_vector (1, 1)); + avalue.xelem (0) = value; + array->assign (idx, avalue); + } + + result->update (array); +} + +extern "C" Complex +octave_jit_complex_div (Complex lhs, Complex rhs) +{ + // see src/OPERATORS/op-cs-cs.cc + if (rhs == 0.0) + gripe_divide_by_zero (); + + return lhs / rhs; +} + +// FIXME: CP form src/xpow.cc +static inline int +xisint (double x) +{ + return (D_NINT (x) == x + && ((x >= 0 && x < INT_MAX) + || (x <= 0 && x > INT_MIN))); +} + +extern "C" Complex +octave_jit_pow_scalar_scalar (double lhs, double rhs) +{ + // FIXME: almost CP from src/xpow.cc + if (lhs < 0.0 && ! xisint (rhs)) + return std::pow (Complex (lhs), rhs); + return std::pow (lhs, rhs); +} + +extern "C" Complex +octave_jit_pow_complex_complex (Complex lhs, Complex rhs) +{ + if (lhs.imag () == 0 && rhs.imag () == 0) + return octave_jit_pow_scalar_scalar (lhs.real (), rhs.real ()); + return std::pow (lhs, rhs); +} + +extern "C" Complex +octave_jit_pow_complex_scalar (Complex lhs, double rhs) +{ + if (lhs.imag () == 0) + return octave_jit_pow_scalar_scalar (lhs.real (), rhs); + return std::pow (lhs, rhs); +} + +extern "C" Complex +octave_jit_pow_scalar_complex (double lhs, Complex rhs) +{ + if (rhs.imag () == 0) + return octave_jit_pow_scalar_scalar (lhs, rhs.real ()); + return std::pow (lhs, rhs); +} + +extern "C" void +octave_jit_print_matrix (jit_matrix *m) +{ + std::cout << *m << std::endl; +} + +static void +gripe_bad_result (void) +{ + error ("incorrect type information given to the JIT compiler"); +} + +// FIXME: Add support for multiple outputs +extern "C" octave_base_value * +octave_jit_call (octave_builtin::fcn fn, size_t nargin, + octave_base_value **argin, jit_type *result_type) +{ + octave_value_list ovl (nargin); + for (size_t i = 0; i < nargin; ++i) + ovl.xelem (i) = octave_value (argin[i]); + + ovl = fn (ovl, 1); + + // These type checks are not strictly required, but I'm guessing that + // incorrect types will be entered on occasion. This will be very difficult to + // debug unless we do the sanity check here. + if (result_type) + { + if (ovl.length () != 1) + { + gripe_bad_result (); + return 0; + } + + octave_value& result = ovl.xelem (0); + jit_type *jtype = jit_typeinfo::join (jit_typeinfo::type_of (result), + result_type); + if (jtype != result_type) + { + gripe_bad_result (); + return 0; + } + + octave_base_value *ret = result.internal_rep (); + ret->grab (); + return ret; + } + + if (! (ovl.length () == 0 + || (ovl.length () == 1 && ovl.xelem (0).is_undefined ()))) + gripe_bad_result (); + + return 0; +} + +// -------------------- jit_range -------------------- +bool +jit_range::all_elements_are_ints () const +{ + Range r (*this); + return r.all_elements_are_ints (); +} + +std::ostream& +operator<< (std::ostream& os, const jit_range& rng) +{ + return os << "Range[" << rng.base << ", " << rng.limit << ", " << rng.inc + << ", " << rng.nelem << "]"; +} + +// -------------------- jit_matrix -------------------- + +std::ostream& +operator<< (std::ostream& os, const jit_matrix& mat) +{ + return os << "Matrix[" << mat.ref_count << ", " << mat.slice_data << ", " + << mat.slice_len << ", " << mat.dimensions << ", " + << mat.array << "]"; +} + +// -------------------- jit_type -------------------- +jit_type::jit_type (const std::string& aname, jit_type *aparent, + llvm::Type *allvm_type, int aid) : + mname (aname), mparent (aparent), llvm_type (allvm_type), mid (aid), + mdepth (aparent ? aparent->mdepth + 1 : 0) +{ + std::memset (msret, 0, sizeof (msret)); + std::memset (mpointer_arg, 0, sizeof (mpointer_arg)); + std::memset (mpack, 0, sizeof (mpack)); + std::memset (munpack, 0, sizeof (munpack)); + + for (size_t i = 0; i < jit_convention::length; ++i) + mpacked_type[i] = llvm_type; +} + +llvm::Type * +jit_type::to_llvm_arg (void) const +{ + return llvm_type ? llvm_type->getPointerTo () : 0; +} + +// -------------------- jit_function -------------------- +jit_function::jit_function () : module (0), llvm_function (0), mresult (0), + call_conv (jit_convention::length), + mcan_error (false) +{} + +jit_function::jit_function (llvm::Module *amodule, + jit_convention::type acall_conv, + const llvm::Twine& aname, jit_type *aresult, + const std::vector<jit_type *>& aargs) + : module (amodule), mresult (aresult), args (aargs), call_conv (acall_conv), + mcan_error (false) +{ + llvm::SmallVector<llvm::Type *, 15> llvm_args; + + llvm::Type *rtype = llvm::Type::getVoidTy (context); + if (mresult) + { + rtype = mresult->packed_type (call_conv); + if (sret ()) + { + llvm_args.push_back (rtype->getPointerTo ()); + rtype = llvm::Type::getVoidTy (context); + } + } + + for (std::vector<jit_type *>::const_iterator iter = args.begin (); + iter != args.end (); ++iter) + { + jit_type *ty = *iter; + assert (ty); + llvm::Type *argty = ty->packed_type (call_conv); + if (ty->pointer_arg (call_conv)) + argty = argty->getPointerTo (); + + llvm_args.push_back (argty); + } + + // we mark all functinos as external linkage because this prevents llvm + // from getting rid of always inline functions + llvm::FunctionType *ft = llvm::FunctionType::get (rtype, llvm_args, false); + llvm_function = llvm::Function::Create (ft, llvm::Function::ExternalLinkage, + aname, module); + if (call_conv == jit_convention::internal) + llvm_function->addFnAttr (llvm::Attribute::AlwaysInline); +} + +jit_function::jit_function (const jit_function& fn, jit_type *aresult, + const std::vector<jit_type *>& aargs) + : module (fn.module), llvm_function (fn.llvm_function), mresult (aresult), + args (aargs), call_conv (fn.call_conv), mcan_error (fn.mcan_error) +{ +} + +jit_function::jit_function (const jit_function& fn) + : module (fn.module), llvm_function (fn.llvm_function), mresult (fn.mresult), + args (fn.args), call_conv (fn.call_conv), mcan_error (fn.mcan_error) +{} + +std::string +jit_function::name (void) const +{ + return llvm_function->getName (); +} + +llvm::BasicBlock * +jit_function::new_block (const std::string& aname, + llvm::BasicBlock *insert_before) +{ + return llvm::BasicBlock::Create (context, aname, llvm_function, + insert_before); +} + +llvm::Value * +jit_function::call (llvm::IRBuilderD& builder, + const std::vector<jit_value *>& in_args) const +{ + assert (in_args.size () == args.size ()); + + std::vector<llvm::Value *> llvm_args (args.size ()); + for (size_t i = 0; i < in_args.size (); ++i) + llvm_args[i] = in_args[i]->to_llvm (); + + return call (builder, llvm_args); +} + +llvm::Value * +jit_function::call (llvm::IRBuilderD& builder, + const std::vector<llvm::Value *>& in_args) const +{ + assert (valid ()); + assert (in_args.size () == args.size ()); + llvm::Function *stacksave + = llvm::Intrinsic::getDeclaration (module, llvm::Intrinsic::stacksave); + llvm::SmallVector<llvm::Value *, 10> llvm_args; + llvm_args.reserve (in_args.size () + sret ()); + + llvm::Value *sret_mem = 0; + llvm::Value *saved_stack = 0; + if (sret ()) + { + saved_stack = builder.CreateCall (stacksave); + sret_mem = builder.CreateAlloca (mresult->packed_type (call_conv)); + llvm_args.push_back (sret_mem); + } + + for (size_t i = 0; i < in_args.size (); ++i) + { + llvm::Value *arg = in_args[i]; + jit_type::convert_fn convert = args[i]->pack (call_conv); + if (convert) + arg = convert (builder, arg); + + if (args[i]->pointer_arg (call_conv)) + { + if (! saved_stack) + saved_stack = builder.CreateCall (stacksave); + + arg = builder.CreateAlloca (args[i]->to_llvm ()); + builder.CreateStore (in_args[i], arg); + } + + llvm_args.push_back (arg); + } + + llvm::Value *ret = builder.CreateCall (llvm_function, llvm_args); + if (sret_mem) + ret = builder.CreateLoad (sret_mem); + + if (mresult) + { + jit_type::convert_fn unpack = mresult->unpack (call_conv); + if (unpack) + ret = unpack (builder, ret); + } + + if (saved_stack) + { + llvm::Function *stackrestore + = llvm::Intrinsic::getDeclaration (module, + llvm::Intrinsic::stackrestore); + builder.CreateCall (stackrestore, saved_stack); + } + + return ret; +} + +llvm::Value * +jit_function::argument (llvm::IRBuilderD& builder, size_t idx) const +{ + assert (idx < args.size ()); + + // FIXME: We should be treating arguments like a list, not a vector. Shouldn't + // matter much for now, as the number of arguments shouldn't be much bigger + // than 4 + llvm::Function::arg_iterator iter = llvm_function->arg_begin (); + if (sret ()) + ++iter; + + for (size_t i = 0; i < idx; ++i, ++iter); + + if (args[idx]->pointer_arg (call_conv)) + return builder.CreateLoad (iter); + + return iter; +} + +void +jit_function::do_return (llvm::IRBuilderD& builder, llvm::Value *rval) +{ + assert (! rval == ! mresult); + + if (rval) + { + jit_type::convert_fn convert = mresult->pack (call_conv); + if (convert) + rval = convert (builder, rval); + + if (sret ()) + builder.CreateStore (rval, llvm_function->arg_begin ()); + else + builder.CreateRet (rval); + } + else + builder.CreateRetVoid (); + + llvm::verifyFunction (*llvm_function); +} + +std::ostream& +operator<< (std::ostream& os, const jit_function& fn) +{ + llvm::Function *lfn = fn.to_llvm (); + os << "jit_function: cc=" << fn.call_conv; + llvm::raw_os_ostream llvm_out (os); + lfn->print (llvm_out); + llvm_out.flush (); + return os; +} + +// -------------------- jit_operation -------------------- +void +jit_operation::add_overload (const jit_function& func, + const std::vector<jit_type*>& args) +{ + if (args.size () >= overloads.size ()) + overloads.resize (args.size () + 1); + + Array<jit_function>& over = overloads[args.size ()]; + dim_vector dv (over.dims ()); + Array<octave_idx_type> idx = to_idx (args); + bool must_resize = false; + + if (dv.length () != idx.numel ()) + { + dv.resize (idx.numel ()); + must_resize = true; + } + + for (octave_idx_type i = 0; i < dv.length (); ++i) + if (dv(i) <= idx(i)) + { + must_resize = true; + dv(i) = idx(i) + 1; + } + + if (must_resize) + over.resize (dv); + + over(idx) = func; +} + +const jit_function& +jit_operation::overload (const std::vector<jit_type*>& types) const +{ + // FIXME: We should search for the next best overload on failure + static jit_function null_overload; + if (types.size () >= overloads.size ()) + return null_overload; + + for (size_t i =0; i < types.size (); ++i) + if (! types[i]) + return null_overload; + + const Array<jit_function>& over = overloads[types.size ()]; + dim_vector dv (over.dims ()); + Array<octave_idx_type> idx = to_idx (types); + for (octave_idx_type i = 0; i < dv.length (); ++i) + if (idx(i) >= dv(i)) + return null_overload; + + return over(idx); +} + +Array<octave_idx_type> +jit_operation::to_idx (const std::vector<jit_type*>& types) const +{ + octave_idx_type numel = types.size (); + if (numel == 1) + numel = 2; + + Array<octave_idx_type> idx (dim_vector (1, numel)); + for (octave_idx_type i = 0; i < static_cast<octave_idx_type> (types.size ()); + ++i) + idx(i) = types[i]->type_id (); + + if (types.size () == 1) + { + idx(1) = idx(0); + idx(0) = 0; + } + + return idx; +} + +// -------------------- jit_typeinfo -------------------- +void +jit_typeinfo::initialize (llvm::Module *m, llvm::ExecutionEngine *e) +{ + new jit_typeinfo (m, e); +} + +jit_typeinfo::jit_typeinfo (llvm::Module *m, llvm::ExecutionEngine *e) + : module (m), engine (e), next_id (0), + builder (*new llvm::IRBuilderD (context)) +{ + instance = this; + + // FIXME: We should be registering types like in octave_value_typeinfo + llvm::Type *any_t = llvm::StructType::create (context, "octave_base_value"); + any_t = any_t->getPointerTo (); + + llvm::Type *scalar_t = llvm::Type::getDoubleTy (context); + llvm::Type *bool_t = llvm::Type::getInt1Ty (context); + llvm::Type *string_t = llvm::Type::getInt8Ty (context); + string_t = string_t->getPointerTo (); + llvm::Type *index_t = llvm::Type::getIntNTy (context, + sizeof(octave_idx_type) * 8); + + llvm::StructType *range_t = llvm::StructType::create (context, "range"); + std::vector<llvm::Type *> range_contents (4, scalar_t); + range_contents[3] = index_t; + range_t->setBody (range_contents); + + llvm::Type *refcount_t = llvm::Type::getIntNTy (context, sizeof(int) * 8); + + llvm::StructType *matrix_t = llvm::StructType::create (context, "matrix"); + llvm::Type *matrix_contents[5]; + matrix_contents[0] = refcount_t->getPointerTo (); + matrix_contents[1] = scalar_t->getPointerTo (); + matrix_contents[2] = index_t; + matrix_contents[3] = index_t->getPointerTo (); + matrix_contents[4] = string_t; + matrix_t->setBody (llvm::makeArrayRef (matrix_contents, 5)); + + llvm::Type *complex_t = llvm::VectorType::get (scalar_t, 2); + + // complex_ret is what is passed to C functions in order to get calling + // convention right + complex_ret = llvm::StructType::create (context, "complex_ret"); + llvm::Type *complex_ret_contents[] = {scalar_t, scalar_t}; + complex_ret->setBody (complex_ret_contents); + + // create types + any = new_type ("any", 0, any_t); + matrix = new_type ("matrix", any, matrix_t); + complex = new_type ("complex", any, complex_t); + scalar = new_type ("scalar", complex, scalar_t); + range = new_type ("range", any, range_t); + string = new_type ("string", any, string_t); + boolean = new_type ("bool", any, bool_t); + index = new_type ("index", any, index_t); + + create_int (8); + create_int (16); + create_int (32); + create_int (64); + + casts.resize (next_id + 1); + identities.resize (next_id + 1); + + // specify calling conventions + // FIXME: We should detect architecture and do something sane based on that + // here we assume x86 or x86_64 + matrix->mark_sret (); + matrix->mark_pointer_arg (); + + range->mark_sret (); + range->mark_pointer_arg (); + + complex->set_pack (jit_convention::external, &jit_typeinfo::pack_complex); + complex->set_unpack (jit_convention::external, &jit_typeinfo::unpack_complex); + complex->set_packed_type (jit_convention::external, complex_ret); + + if (sizeof (void *) == 4) + complex->mark_sret (); + + // bind global variables + lerror_state = new llvm::GlobalVariable (*module, bool_t, false, + llvm::GlobalValue::ExternalLinkage, + 0, "error_state"); + engine->addGlobalMapping (lerror_state, + reinterpret_cast<void *> (&error_state)); + + // any with anything is an any op + jit_function fn; + jit_type *binary_op_type = intN (sizeof (octave_value::binary_op) * 8); + llvm::Type *llvm_bo_type = binary_op_type->to_llvm (); + jit_function any_binary = create_function (jit_convention::external, + "octave_jit_binary_any_any", + any, binary_op_type, any, any); + any_binary.mark_can_error (); + binary_ops.resize (octave_value::num_binary_ops); + for (size_t i = 0; i < octave_value::num_binary_ops; ++i) + { + octave_value::binary_op op = static_cast<octave_value::binary_op> (i); + std::string op_name = octave_value::binary_op_as_string (op); + binary_ops[i].stash_name ("binary" + op_name); + } + + for (int op = 0; op < octave_value::num_binary_ops; ++op) + { + llvm::Twine fn_name ("octave_jit_binary_any_any_"); + fn_name = fn_name + llvm::Twine (op); + + fn = create_function (jit_convention::internal, fn_name, any, any, any); + fn.mark_can_error (); + llvm::BasicBlock *block = fn.new_block (); + builder.SetInsertPoint (block); + llvm::APInt op_int(sizeof (octave_value::binary_op) * 8, op, + std::numeric_limits<octave_value::binary_op>::is_signed); + llvm::Value *op_as_llvm = llvm::ConstantInt::get (llvm_bo_type, op_int); + llvm::Value *ret = any_binary.call (builder, op_as_llvm, + fn.argument (builder, 0), + fn.argument (builder, 1)); + fn.do_return (builder, ret); + binary_ops[op].add_overload (fn); + } + + // grab any + fn = create_function (jit_convention::external, "octave_jit_grab_any", any, + any); + grab_fn.add_overload (fn); + grab_fn.stash_name ("grab"); + + // grab matrix + fn = create_function (jit_convention::external, "octave_jit_grab_matrix", + matrix, matrix); + grab_fn.add_overload (fn); + + // release any + fn = create_function (jit_convention::external, "octave_jit_release_any", 0, + any); + release_fn.add_overload (fn); + release_fn.stash_name ("release"); + + // release matrix + fn = create_function (jit_convention::external, "octave_jit_release_matrix", + 0, matrix); + release_fn.add_overload (fn); + + // release scalar + fn = create_identity (scalar); + release_fn.add_overload (fn); + + // release complex + fn = create_identity (complex); + release_fn.add_overload (fn); + + // release index + fn = create_identity (index); + release_fn.add_overload (fn); + + // now for binary scalar operations + // FIXME: Finish all operations + add_binary_op (scalar, octave_value::op_add, llvm::Instruction::FAdd); + add_binary_op (scalar, octave_value::op_sub, llvm::Instruction::FSub); + add_binary_op (scalar, octave_value::op_mul, llvm::Instruction::FMul); + add_binary_op (scalar, octave_value::op_el_mul, llvm::Instruction::FMul); + + add_binary_fcmp (scalar, octave_value::op_lt, llvm::CmpInst::FCMP_ULT); + add_binary_fcmp (scalar, octave_value::op_le, llvm::CmpInst::FCMP_ULE); + add_binary_fcmp (scalar, octave_value::op_eq, llvm::CmpInst::FCMP_UEQ); + add_binary_fcmp (scalar, octave_value::op_ge, llvm::CmpInst::FCMP_UGE); + add_binary_fcmp (scalar, octave_value::op_gt, llvm::CmpInst::FCMP_UGT); + add_binary_fcmp (scalar, octave_value::op_ne, llvm::CmpInst::FCMP_UNE); + + jit_function gripe_div0 = create_function (jit_convention::external, + "gripe_divide_by_zero", 0); + gripe_div0.mark_can_error (); + + // divide is annoying because it might error + fn = create_function (jit_convention::internal, + "octave_jit_div_scalar_scalar", scalar, scalar, scalar); + fn.mark_can_error (); + + llvm::BasicBlock *body = fn.new_block (); + builder.SetInsertPoint (body); + { + llvm::BasicBlock *warn_block = fn.new_block ("warn"); + llvm::BasicBlock *normal_block = fn.new_block ("normal"); + + llvm::Value *zero = llvm::ConstantFP::get (scalar_t, 0); + llvm::Value *check = builder.CreateFCmpUEQ (zero, fn.argument (builder, 0)); + builder.CreateCondBr (check, warn_block, normal_block); + + builder.SetInsertPoint (warn_block); + gripe_div0.call (builder); + builder.CreateBr (normal_block); + + builder.SetInsertPoint (normal_block); + llvm::Value *ret = builder.CreateFDiv (fn.argument (builder, 0), + fn.argument (builder, 1)); + fn.do_return (builder, ret); + } + binary_ops[octave_value::op_div].add_overload (fn); + binary_ops[octave_value::op_el_div].add_overload (fn); + + // ldiv is the same as div with the operators reversed + fn = mirror_binary (fn); + binary_ops[octave_value::op_ldiv].add_overload (fn); + binary_ops[octave_value::op_el_ldiv].add_overload (fn); + + // In general, the result of scalar ^ scalar is a complex number. We might be + // able to improve on this if we keep track of the range of values varaibles + // can take on. + fn = create_function (jit_convention::external, + "octave_jit_pow_scalar_scalar", complex, scalar, + scalar); + binary_ops[octave_value::op_pow].add_overload (fn); + binary_ops[octave_value::op_el_pow].add_overload (fn); + + // now for binary complex operations + add_binary_op (complex, octave_value::op_add, llvm::Instruction::FAdd); + add_binary_op (complex, octave_value::op_sub, llvm::Instruction::FSub); + + fn = create_function (jit_convention::internal, + "octave_jit_*_complex_complex", complex, complex, + complex); + body = fn.new_block (); + builder.SetInsertPoint (body); + { + // (x0*x1 - y0*y1, x0*y1 + y0*x1) = (x0,y0) * (x1,y1) + // We compute this in one vectorized multiplication, a subtraction, and an + // addition. + llvm::Value *lhs = fn.argument (builder, 0); + llvm::Value *rhs = fn.argument (builder, 1); + + // FIXME: We need a better way of doing this, working with llvm's IR + // directly is sort of a pain. + llvm::Value *zero = builder.getInt32 (0); + llvm::Value *one = builder.getInt32 (1); + llvm::Value *two = builder.getInt32 (2); + llvm::Value *three = builder.getInt32 (3); + + llvm::Type *vec4 = llvm::VectorType::get (scalar_t, 4); + llvm::Value *mlhs = llvm::UndefValue::get (vec4); + llvm::Value *mrhs = mlhs; + + llvm::Value *temp = complex_real (lhs); + mlhs = builder.CreateInsertElement (mlhs, temp, zero); + mlhs = builder.CreateInsertElement (mlhs, temp, two); + temp = complex_imag (lhs); + mlhs = builder.CreateInsertElement (mlhs, temp, one); + mlhs = builder.CreateInsertElement (mlhs, temp, three); + + temp = complex_real (rhs); + mrhs = builder.CreateInsertElement (mrhs, temp, zero); + mrhs = builder.CreateInsertElement (mrhs, temp, three); + temp = complex_imag (rhs); + mrhs = builder.CreateInsertElement (mrhs, temp, one); + mrhs = builder.CreateInsertElement (mrhs, temp, two); + + llvm::Value *mres = builder.CreateFMul (mlhs, mrhs); + llvm::Value *tlhs = builder.CreateExtractElement (mres, zero); + llvm::Value *trhs = builder.CreateExtractElement (mres, one); + llvm::Value *ret_real = builder.CreateFSub (tlhs, trhs); + + tlhs = builder.CreateExtractElement (mres, two); + trhs = builder.CreateExtractElement (mres, three); + llvm::Value *ret_imag = builder.CreateFAdd (tlhs, trhs); + fn.do_return (builder, complex_new (ret_real, ret_imag)); + } + + binary_ops[octave_value::op_mul].add_overload (fn); + binary_ops[octave_value::op_el_mul].add_overload (fn); + + jit_function complex_div = create_function (jit_convention::external, + "octave_jit_complex_div", + complex, complex, complex); + complex_div.mark_can_error (); + binary_ops[octave_value::op_div].add_overload (fn); + binary_ops[octave_value::op_ldiv].add_overload (fn); + + fn = mirror_binary (complex_div); + binary_ops[octave_value::op_ldiv].add_overload (fn); + binary_ops[octave_value::op_el_ldiv].add_overload (fn); + + fn = create_function (jit_convention::external, + "octave_jit_pow_complex_complex", complex, complex, + complex); + binary_ops[octave_value::op_pow].add_overload (fn); + binary_ops[octave_value::op_el_pow].add_overload (fn); + + fn = create_function (jit_convention::internal, + "octave_jit_*_scalar_complex", complex, scalar, + complex); + jit_function mul_scalar_complex = fn; + body = fn.new_block (); + builder.SetInsertPoint (body); + { + llvm::Value *lhs = fn.argument (builder, 0); + llvm::Value *tlhs = complex_new (lhs, lhs); + llvm::Value *rhs = fn.argument (builder, 1); + fn.do_return (builder, builder.CreateFMul (tlhs, rhs)); + } + binary_ops[octave_value::op_mul].add_overload (fn); + binary_ops[octave_value::op_el_mul].add_overload (fn); + + + fn = mirror_binary (mul_scalar_complex); + binary_ops[octave_value::op_mul].add_overload (fn); + binary_ops[octave_value::op_el_mul].add_overload (fn); + + fn = create_function (jit_convention::internal, "octave_jit_+_scalar_complex", + complex, scalar, complex); + body = fn.new_block (); + builder.SetInsertPoint (body); + { + llvm::Value *lhs = fn.argument (builder, 0); + llvm::Value *rhs = fn.argument (builder, 1); + llvm::Value *real = builder.CreateFAdd (lhs, complex_real (rhs)); + fn.do_return (builder, complex_real (rhs, real)); + } + binary_ops[octave_value::op_add].add_overload (fn); + + fn = mirror_binary (fn); + binary_ops[octave_value::op_add].add_overload (fn); + + fn = create_function (jit_convention::internal, "octave_jit_-_complex_scalar", + complex, complex, scalar); + body = fn.new_block (); + builder.SetInsertPoint (body); + { + llvm::Value *lhs = fn.argument (builder, 0); + llvm::Value *rhs = fn.argument (builder, 1); + llvm::Value *real = builder.CreateFSub (complex_real (lhs), rhs); + fn.do_return (builder, complex_real (lhs, real)); + } + binary_ops[octave_value::op_sub].add_overload (fn); + + fn = create_function (jit_convention::internal, "octave_jit_-_scalar_complex", + complex, scalar, complex); + body = fn.new_block (); + builder.SetInsertPoint (body); + { + llvm::Value *lhs = fn.argument (builder, 0); + llvm::Value *rhs = fn.argument (builder, 1); + llvm::Value *real = builder.CreateFSub (lhs, complex_real (rhs)); + fn.do_return (builder, complex_real (rhs, real)); + } + binary_ops[octave_value::op_sub].add_overload (fn); + + fn = create_function (jit_convention::external, + "octave_jit_pow_scalar_complex", complex, scalar, + complex); + binary_ops[octave_value::op_pow].add_overload (fn); + binary_ops[octave_value::op_el_pow].add_overload (fn); + + fn = create_function (jit_convention::external, + "octave_jit_pow_complex_scalar", complex, complex, + scalar); + binary_ops[octave_value::op_pow].add_overload (fn); + binary_ops[octave_value::op_el_pow].add_overload (fn); + + // now for binary index operators + add_binary_op (index, octave_value::op_add, llvm::Instruction::Add); + + // and binary bool operators + add_binary_op (boolean, octave_value::op_el_or, llvm::Instruction::Or); + add_binary_op (boolean, octave_value::op_el_and, llvm::Instruction::And); + + // now for printing functions + print_fn.stash_name ("print"); + add_print (any); + add_print (scalar); + + // initialize for loop + for_init_fn.stash_name ("for_init"); + + fn = create_function (jit_convention::internal, "octave_jit_for_range_init", + index, range); + body = fn.new_block (); + builder.SetInsertPoint (body); + { + llvm::Value *zero = llvm::ConstantInt::get (index_t, 0); + fn.do_return (builder, zero); + } + for_init_fn.add_overload (fn); + + // bounds check for for loop + for_check_fn.stash_name ("for_check"); + + fn = create_function (jit_convention::internal, "octave_jit_for_range_check", + boolean, range, index); + body = fn.new_block (); + builder.SetInsertPoint (body); + { + llvm::Value *nelem + = builder.CreateExtractValue (fn.argument (builder, 0), 3); + llvm::Value *idx = fn.argument (builder, 1); + llvm::Value *ret = builder.CreateICmpULT (idx, nelem); + fn.do_return (builder, ret); + } + for_check_fn.add_overload (fn); + + // index variabe for for loop + for_index_fn.stash_name ("for_index"); + + fn = create_function (jit_convention::internal, "octave_jit_for_range_idx", + scalar, range, index); + body = fn.new_block (); + builder.SetInsertPoint (body); + { + llvm::Value *idx = fn.argument (builder, 1); + llvm::Value *didx = builder.CreateSIToFP (idx, scalar_t); + llvm::Value *rng = fn.argument (builder, 0); + llvm::Value *base = builder.CreateExtractValue (rng, 0); + llvm::Value *inc = builder.CreateExtractValue (rng, 2); + + llvm::Value *ret = builder.CreateFMul (didx, inc); + ret = builder.CreateFAdd (base, ret); + fn.do_return (builder, ret); + } + for_index_fn.add_overload (fn); + + // logically true + logically_true_fn.stash_name ("logically_true"); + + jit_function gripe_nantl + = create_function (jit_convention::external, + "octave_jit_gripe_nan_to_logical_conversion", 0); + gripe_nantl.mark_can_error (); + + fn = create_function (jit_convention::internal, + "octave_jit_logically_true_scalar", boolean, scalar); + fn.mark_can_error (); + + body = fn.new_block (); + builder.SetInsertPoint (body); + { + llvm::BasicBlock *error_block = fn.new_block ("error"); + llvm::BasicBlock *normal_block = fn.new_block ("normal"); + + llvm::Value *check = builder.CreateFCmpUNE (fn.argument (builder, 0), + fn.argument (builder, 0)); + builder.CreateCondBr (check, error_block, normal_block); + + builder.SetInsertPoint (error_block); + gripe_nantl.call (builder); + builder.CreateBr (normal_block); + builder.SetInsertPoint (normal_block); + + llvm::Value *zero = llvm::ConstantFP::get (scalar_t, 0); + llvm::Value *ret = builder.CreateFCmpONE (fn.argument (builder, 0), zero); + fn.do_return (builder, ret); + } + logically_true_fn.add_overload (fn); + + // logically_true boolean + fn = create_identity (boolean); + logically_true_fn.add_overload (fn); + + // make_range + // FIXME: May be benificial to implement all in LLVM + make_range_fn.stash_name ("make_range"); + jit_function compute_nelem + = create_function (jit_convention::external, "octave_jit_compute_nelem", + index, scalar, scalar, scalar); + + fn = create_function (jit_convention::internal, "octave_jit_make_range", + range, scalar, scalar, scalar); + body = fn.new_block (); + builder.SetInsertPoint (body); + { + llvm::Value *base = fn.argument (builder, 0); + llvm::Value *limit = fn.argument (builder, 1); + llvm::Value *inc = fn.argument (builder, 2); + llvm::Value *nelem = compute_nelem.call (builder, base, limit, inc); + + llvm::Value *dzero = llvm::ConstantFP::get (scalar_t, 0); + llvm::Value *izero = llvm::ConstantInt::get (index_t, 0); + llvm::Value *rng = llvm::ConstantStruct::get (range_t, dzero, dzero, dzero, + izero, NULL); + rng = builder.CreateInsertValue (rng, base, 0); + rng = builder.CreateInsertValue (rng, limit, 1); + rng = builder.CreateInsertValue (rng, inc, 2); + rng = builder.CreateInsertValue (rng, nelem, 3); + fn.do_return (builder, rng); + } + make_range_fn.add_overload (fn); + + // paren_subsref + jit_type *jit_int = intN (sizeof (int) * 8); + llvm::Type *int_t = jit_int->to_llvm (); + jit_function ginvalid_index + = create_function (jit_convention::external, "octave_jit_ginvalid_index", + 0); + jit_function gindex_range = create_function (jit_convention::external, + "octave_jit_gindex_range", + 0, jit_int, jit_int, index, + index); + + fn = create_function (jit_convention::internal, "()subsref", scalar, matrix, + scalar); + fn.mark_can_error (); + + body = fn.new_block (); + builder.SetInsertPoint (body); + { + llvm::Value *one = llvm::ConstantInt::get (index_t, 1); + llvm::Value *ione; + if (index_t == int_t) + ione = one; + else + ione = llvm::ConstantInt::get (int_t, 1); + + llvm::Value *undef = llvm::UndefValue::get (scalar_t); + llvm::Value *mat = fn.argument (builder, 0); + llvm::Value *idx = fn.argument (builder, 1); + + // convert index to scalar to integer, and check index >= 1 + llvm::Value *int_idx = builder.CreateFPToSI (idx, index_t); + llvm::Value *check_idx = builder.CreateSIToFP (int_idx, scalar_t); + llvm::Value *cond0 = builder.CreateFCmpUNE (idx, check_idx); + llvm::Value *cond1 = builder.CreateICmpSLT (int_idx, one); + llvm::Value *cond = builder.CreateOr (cond0, cond1); + + llvm::BasicBlock *done = fn.new_block ("done"); + llvm::BasicBlock *conv_error = fn.new_block ("conv_error", done); + llvm::BasicBlock *normal = fn.new_block ("normal", done); + builder.CreateCondBr (cond, conv_error, normal); + + builder.SetInsertPoint (conv_error); + ginvalid_index.call (builder); + builder.CreateBr (done); + + builder.SetInsertPoint (normal); + llvm::Value *len = builder.CreateExtractValue (mat, + llvm::ArrayRef<unsigned> (2)); + cond = builder.CreateICmpSGT (int_idx, len); + + + llvm::BasicBlock *bounds_error = fn.new_block ("bounds_error", done); + llvm::BasicBlock *success = fn.new_block ("success", done); + builder.CreateCondBr (cond, bounds_error, success); + + builder.SetInsertPoint (bounds_error); + gindex_range.call (builder, ione, ione, int_idx, len); + builder.CreateBr (done); + + builder.SetInsertPoint (success); + llvm::Value *data = builder.CreateExtractValue (mat, + llvm::ArrayRef<unsigned> (1)); + llvm::Value *gep = builder.CreateInBoundsGEP (data, int_idx); + llvm::Value *ret = builder.CreateLoad (gep); + builder.CreateBr (done); + + builder.SetInsertPoint (done); + + llvm::PHINode *merge = llvm::PHINode::Create (scalar_t, 3); + builder.Insert (merge); + merge->addIncoming (undef, conv_error); + merge->addIncoming (undef, bounds_error); + merge->addIncoming (ret, success); + fn.do_return (builder, merge); + } + paren_subsref_fn.add_overload (fn); + + // paren subsasgn + paren_subsasgn_fn.stash_name ("()subsasgn"); + + jit_function resize_paren_subsasgn + = create_function (jit_convention::external, + "octave_jit_paren_subsasgn_impl", matrix, index, scalar); + fn = create_function (jit_convention::internal, "octave_jit_paren_subsasgn", + matrix, matrix, scalar, scalar); + fn.mark_can_error (); + body = fn.new_block (); + builder.SetInsertPoint (body); + { + llvm::Value *one = llvm::ConstantInt::get (index_t, 1); + + llvm::Value *mat = fn.argument (builder, 0); + llvm::Value *idx = fn.argument (builder, 1); + llvm::Value *value = fn.argument (builder, 2); + + llvm::Value *int_idx = builder.CreateFPToSI (idx, index_t); + llvm::Value *check_idx = builder.CreateSIToFP (int_idx, scalar_t); + llvm::Value *cond0 = builder.CreateFCmpUNE (idx, check_idx); + llvm::Value *cond1 = builder.CreateICmpSLT (int_idx, one); + llvm::Value *cond = builder.CreateOr (cond0, cond1); + + llvm::BasicBlock *done = fn.new_block ("done"); + + llvm::BasicBlock *conv_error = fn.new_block ("conv_error", done); + llvm::BasicBlock *normal = fn.new_block ("normal", done); + builder.CreateCondBr (cond, conv_error, normal); + builder.SetInsertPoint (conv_error); + ginvalid_index.call (builder); + builder.CreateBr (done); + + builder.SetInsertPoint (normal); + llvm::Value *len = builder.CreateExtractValue (mat, + llvm::ArrayRef<unsigned> (2)); + cond0 = builder.CreateICmpSGT (int_idx, len); + + llvm::Value *rcount = builder.CreateExtractValue (mat, 0); + rcount = builder.CreateLoad (rcount); + cond1 = builder.CreateICmpSGT (rcount, one); + cond = builder.CreateOr (cond0, cond1); + + llvm::BasicBlock *bounds_error = fn.new_block ("bounds_error", done); + llvm::BasicBlock *success = fn.new_block ("success", done); + builder.CreateCondBr (cond, bounds_error, success); + + // resize on out of bounds access + builder.SetInsertPoint (bounds_error); + llvm::Value *resize_result = resize_paren_subsasgn.call (builder, int_idx, + value); + builder.CreateBr (done); + + builder.SetInsertPoint (success); + llvm::Value *data = builder.CreateExtractValue (mat, + llvm::ArrayRef<unsigned> (1)); + llvm::Value *gep = builder.CreateInBoundsGEP (data, int_idx); + builder.CreateStore (value, gep); + builder.CreateBr (done); + + builder.SetInsertPoint (done); + + llvm::PHINode *merge = llvm::PHINode::Create (matrix_t, 3); + builder.Insert (merge); + merge->addIncoming (mat, conv_error); + merge->addIncoming (resize_result, bounds_error); + merge->addIncoming (mat, success); + fn.do_return (builder, merge); + } + paren_subsasgn_fn.add_overload (fn); + + fn = create_function (jit_convention::external, + "octave_jit_paren_subsasgn_matrix_range", matrix, + matrix, range, scalar); + fn.mark_can_error (); + paren_subsasgn_fn.add_overload (fn); + + casts[any->type_id ()].stash_name ("(any)"); + casts[scalar->type_id ()].stash_name ("(scalar)"); + casts[complex->type_id ()].stash_name ("(complex)"); + casts[matrix->type_id ()].stash_name ("(matrix)"); + + // cast any <- matrix + fn = create_function (jit_convention::external, "octave_jit_cast_any_matrix", + any, matrix); + casts[any->type_id ()].add_overload (fn); + + // cast matrix <- any + fn = create_function (jit_convention::external, "octave_jit_cast_matrix_any", + matrix, any); + casts[matrix->type_id ()].add_overload (fn); + + // cast any <- scalar + fn = create_function (jit_convention::external, "octave_jit_cast_any_scalar", + any, scalar); + casts[any->type_id ()].add_overload (fn); + + // cast scalar <- any + fn = create_function (jit_convention::external, "octave_jit_cast_scalar_any", + scalar, any); + casts[scalar->type_id ()].add_overload (fn); + + // cast any <- complex + fn = create_function (jit_convention::external, "octave_jit_cast_any_complex", + any, complex); + casts[any->type_id ()].add_overload (fn); + + // cast complex <- any + fn = create_function (jit_convention::external, "octave_jit_cast_complex_any", + complex, any); + casts[complex->type_id ()].add_overload (fn); + + // cast complex <- scalar + fn = create_function (jit_convention::internal, + "octave_jit_cast_complex_scalar", complex, scalar); + body = fn.new_block (); + builder.SetInsertPoint (body); + { + llvm::Value *zero = llvm::ConstantFP::get (scalar_t, 0); + fn.do_return (builder, complex_new (fn.argument (builder, 0), zero)); + } + casts[complex->type_id ()].add_overload (fn); + + // cast scalar <- complex + fn = create_function (jit_convention::internal, + "octave_jit_cast_scalar_complex", scalar, complex); + body = fn.new_block (); + builder.SetInsertPoint (body); + fn.do_return (builder, complex_real (fn.argument (builder, 0))); + casts[scalar->type_id ()].add_overload (fn); + + // cast any <- any + fn = create_identity (any); + casts[any->type_id ()].add_overload (fn); + + // cast scalar <- scalar + fn = create_identity (scalar); + casts[scalar->type_id ()].add_overload (fn); + + // cast complex <- complex + fn = create_identity (complex); + casts[complex->type_id ()].add_overload (fn); + + // -------------------- builtin functions -------------------- + add_builtin ("#unknown_function"); + unknown_function = builtins["#unknown_function"]; + + add_builtin ("sin"); + register_intrinsic ("sin", llvm::Intrinsic::sin, scalar, scalar); + register_generic ("sin", matrix, matrix); + + add_builtin ("cos"); + register_intrinsic ("cos", llvm::Intrinsic::cos, scalar, scalar); + register_generic ("cos", matrix, matrix); + + add_builtin ("exp"); + register_intrinsic ("exp", llvm::Intrinsic::cos, scalar, scalar); + register_generic ("exp", matrix, matrix); + + casts.resize (next_id + 1); + jit_function any_id = create_identity (any); + jit_function release_any = get_release (any); + std::vector<jit_type *> args; + args.resize (1); + + for (std::map<std::string, jit_type *>::iterator iter = builtins.begin (); + iter != builtins.end (); ++iter) + { + jit_type *btype = iter->second; + args[0] = btype; + + release_fn.add_overload (jit_function (release_any, 0, args)); + casts[any->type_id ()].add_overload (jit_function (any_id, any, args)); + + args[0] = any; + casts[btype->type_id ()].add_overload (jit_function (any_id, btype, + args)); + } +} + +void +jit_typeinfo::add_print (jit_type *ty) +{ + std::stringstream name; + name << "octave_jit_print_" << ty->name (); + jit_function fn = create_function (jit_convention::external, name.str (), 0, + intN (8), ty); + print_fn.add_overload (fn); +} + +// FIXME: cp between add_binary_op, add_binary_icmp, and add_binary_fcmp +void +jit_typeinfo::add_binary_op (jit_type *ty, int op, int llvm_op) +{ + std::stringstream fname; + octave_value::binary_op ov_op = static_cast<octave_value::binary_op>(op); + fname << "octave_jit_" << octave_value::binary_op_as_string (ov_op) + << "_" << ty->name (); + + jit_function fn = create_function (jit_convention::internal, fname.str (), + ty, ty, ty); + llvm::BasicBlock *block = fn.new_block (); + builder.SetInsertPoint (block); + llvm::Instruction::BinaryOps temp + = static_cast<llvm::Instruction::BinaryOps>(llvm_op); + + llvm::Value *ret = builder.CreateBinOp (temp, fn.argument (builder, 0), + fn.argument (builder, 1)); + fn.do_return (builder, ret); + binary_ops[op].add_overload (fn); +} + +void +jit_typeinfo::add_binary_icmp (jit_type *ty, int op, int llvm_op) +{ + std::stringstream fname; + octave_value::binary_op ov_op = static_cast<octave_value::binary_op>(op); + fname << "octave_jit" << octave_value::binary_op_as_string (ov_op) + << "_" << ty->name (); + + jit_function fn = create_function (jit_convention::internal, fname.str (), + boolean, ty, ty); + llvm::BasicBlock *block = fn.new_block (); + builder.SetInsertPoint (block); + llvm::CmpInst::Predicate temp + = static_cast<llvm::CmpInst::Predicate>(llvm_op); + llvm::Value *ret = builder.CreateICmp (temp, fn.argument (builder, 0), + fn.argument (builder, 1)); + fn.do_return (builder, ret); + binary_ops[op].add_overload (fn); +} + +void +jit_typeinfo::add_binary_fcmp (jit_type *ty, int op, int llvm_op) +{ + std::stringstream fname; + octave_value::binary_op ov_op = static_cast<octave_value::binary_op>(op); + fname << "octave_jit" << octave_value::binary_op_as_string (ov_op) + << "_" << ty->name (); + + jit_function fn = create_function (jit_convention::internal, fname.str (), + boolean, ty, ty); + llvm::BasicBlock *block = fn.new_block (); + builder.SetInsertPoint (block); + llvm::CmpInst::Predicate temp + = static_cast<llvm::CmpInst::Predicate>(llvm_op); + llvm::Value *ret = builder.CreateFCmp (temp, fn.argument (builder, 0), + fn.argument (builder, 1)); + fn.do_return (builder, ret); + binary_ops[op].add_overload (fn); +} + +jit_function +jit_typeinfo::create_function (jit_convention::type cc, const llvm::Twine& name, + jit_type *ret, + const std::vector<jit_type *>& args) +{ + jit_function result (module, cc, name, ret, args); + return result; +} + +jit_function +jit_typeinfo::create_identity (jit_type *type) +{ + size_t id = type->type_id (); + if (id >= identities.size ()) + identities.resize (id + 1); + + if (! identities[id].valid ()) + { + jit_function fn = create_function (jit_convention::internal, "id", type, + type); + llvm::BasicBlock *body = fn.new_block (); + builder.SetInsertPoint (body); + fn.do_return (builder, fn.argument (builder, 0)); + return identities[id] = fn; + } + + return identities[id]; +} + +llvm::Value * +jit_typeinfo::do_insert_error_check (llvm::IRBuilderD& builder) +{ + return builder.CreateLoad (lerror_state); +} + +void +jit_typeinfo::add_builtin (const std::string& name) +{ + jit_type *btype = new_type (name, any, any->to_llvm ()); + builtins[name] = btype; + + octave_builtin *ov_builtin = find_builtin (name); + if (ov_builtin) + ov_builtin->stash_jit (*btype); +} + +void +jit_typeinfo::register_intrinsic (const std::string& name, size_t iid, + jit_type *result, + const std::vector<jit_type *>& args) +{ + jit_type *builtin_type = builtins[name]; + size_t nargs = args.size (); + llvm::SmallVector<llvm::Type *, 5> llvm_args (nargs); + for (size_t i = 0; i < nargs; ++i) + llvm_args[i] = args[i]->to_llvm (); + + llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID> (iid); + llvm::Function *ifun = llvm::Intrinsic::getDeclaration (module, id, + llvm_args); + std::stringstream fn_name; + fn_name << "octave_jit_" << name; + + std::vector<jit_type *> args1 (nargs + 1); + args1[0] = builtin_type; + std::copy (args.begin (), args.end (), args1.begin () + 1); + + // The first argument will be the Octave function, but we already know that + // the function call is the equivalent of the intrinsic, so we ignore it and + // call the intrinsic with the remaining arguments. + jit_function fn = create_function (jit_convention::internal, fn_name.str (), + result, args1); + llvm::BasicBlock *body = fn.new_block (); + builder.SetInsertPoint (body); + + llvm::SmallVector<llvm::Value *, 5> fargs (nargs); + for (size_t i = 0; i < nargs; ++i) + fargs[i] = fn.argument (builder, i + 1); + + llvm::Value *ret = builder.CreateCall (ifun, fargs); + fn.do_return (builder, ret); + paren_subsref_fn.add_overload (fn); +} + +octave_builtin * +jit_typeinfo::find_builtin (const std::string& name) +{ + // FIXME: Finalize what we want to store in octave_builtin, then add functions + // to access these values in octave_value + octave_value ov_builtin = symbol_table::find (name); + return dynamic_cast<octave_builtin *> (ov_builtin.internal_rep ()); +} + +void +jit_typeinfo::register_generic (const std::string&, jit_type *, + const std::vector<jit_type *>&) +{ + // FIXME: Implement +} + +jit_function +jit_typeinfo::mirror_binary (const jit_function& fn) +{ + jit_function ret = create_function (jit_convention::internal, + fn.name () + "_reverse", + fn.result (), fn.argument_type (1), + fn.argument_type (0)); + if (fn.can_error ()) + ret.mark_can_error (); + + llvm::BasicBlock *body = ret.new_block (); + builder.SetInsertPoint (body); + llvm::Value *result = fn.call (builder, ret.argument (builder, 1), + ret.argument (builder, 0)); + if (ret.result ()) + ret.do_return (builder, result); + else + ret.do_return (builder); + + return ret; +} + +llvm::Value * +jit_typeinfo::pack_complex (llvm::IRBuilderD& bld, llvm::Value *cplx) +{ + llvm::Type *complex_ret = instance->complex_ret; + llvm::Value *real = bld.CreateExtractElement (cplx, bld.getInt32 (0)); + llvm::Value *imag = bld.CreateExtractElement (cplx, bld.getInt32 (1)); + llvm::Value *ret = llvm::UndefValue::get (complex_ret); + ret = bld.CreateInsertValue (ret, real, 0); + return bld.CreateInsertValue (ret, imag, 1); +} + +llvm::Value * +jit_typeinfo::unpack_complex (llvm::IRBuilderD& bld, llvm::Value *result) +{ + llvm::Type *complex_t = get_complex ()->to_llvm (); + llvm::Value *real = bld.CreateExtractValue (result, 0); + llvm::Value *imag = bld.CreateExtractValue (result, 1); + llvm::Value *ret = llvm::UndefValue::get (complex_t); + ret = bld.CreateInsertElement (ret, real, bld.getInt32 (0)); + return bld.CreateInsertElement (ret, imag, bld.getInt32 (1)); +} + +llvm::Value * +jit_typeinfo::complex_real (llvm::Value *cx) +{ + return builder.CreateExtractElement (cx, builder.getInt32 (0)); +} + +llvm::Value * +jit_typeinfo::complex_real (llvm::Value *cx, llvm::Value *real) +{ + return builder.CreateInsertElement (cx, real, builder.getInt32 (0)); +} + +llvm::Value * +jit_typeinfo::complex_imag (llvm::Value *cx) +{ + return builder.CreateExtractElement (cx, builder.getInt32 (1)); +} + +llvm::Value * +jit_typeinfo::complex_imag (llvm::Value *cx, llvm::Value *imag) +{ + return builder.CreateInsertElement (cx, imag, builder.getInt32 (1)); +} + +llvm::Value * +jit_typeinfo::complex_new (llvm::Value *real, llvm::Value *imag) +{ + llvm::Value *ret = llvm::UndefValue::get (complex->to_llvm ()); + ret = complex_real (ret, real); + return complex_imag (ret, imag); +} + +void +jit_typeinfo::create_int (size_t nbits) +{ + std::stringstream tname; + tname << "int" << nbits; + ints[nbits] = new_type (tname.str (), any, llvm::Type::getIntNTy (context, + nbits)); +} + +jit_type * +jit_typeinfo::intN (size_t nbits) const +{ + std::map<size_t, jit_type *>::const_iterator iter = ints.find (nbits); + if (iter != ints.end ()) + return iter->second; + + throw jit_fail_exception ("No such integer type"); +} + +jit_type * +jit_typeinfo::do_type_of (const octave_value &ov) const +{ + if (ov.is_function ()) + { + // FIXME: This is ugly, we need to finalize how we want to to this, then + // have octave_value fully support the needed functionality + octave_builtin *builtin + = dynamic_cast<octave_builtin *> (ov.internal_rep ()); + return builtin && builtin->to_jit () ? builtin->to_jit () + : unknown_function; + } + + if (ov.is_range ()) + return get_range (); + + if (ov.is_double_type ()) + { + if (ov.is_real_scalar ()) + return get_scalar (); + + if (ov.is_matrix_type ()) + return get_matrix (); + } + + if (ov.is_complex_scalar ()) + return get_complex (); + + return get_any (); +} + +jit_type* +jit_typeinfo::new_type (const std::string& name, jit_type *parent, + llvm::Type *llvm_type) +{ + jit_type *ret = new jit_type (name, parent, llvm_type, next_id++); + id_to_type.push_back (ret); + return ret; +} + +#endif
new file mode 100644 --- /dev/null +++ b/src/jit-typeinfo.h @@ -0,0 +1,661 @@ +/* + +Copyright (C) 2012 Max Brister <max@2bass.com> + +This file is part of Octave. + +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 3 of the License, or (at your +option) any later version. + +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, see +<http://www.gnu.org/licenses/>. + +*/ + +#if !defined (octave_jit_typeinfo_h) +#define octave_jit_typeinfo_h 1 + +#ifdef HAVE_LLVM + +#include <map> +#include <vector> + +#include "Range.h" +#include "jit-util.h" + +// Defines the type system used by jit and a singleton class, jit_typeinfo, to +// manage the types. +// +// FIXME: +// Operations are defined and implemented in jit_typeinfo. Eventually they +// should be moved elsewhere. (just like with octave_typeinfo) + +// jit_range is compatable with the llvm range structure +struct +jit_range +{ + jit_range (const Range& from) : base (from.base ()), limit (from.limit ()), + inc (from.inc ()), nelem (from.nelem ()) + {} + + operator Range () const + { + return Range (base, limit, inc); + } + + bool all_elements_are_ints () const; + + double base; + double limit; + double inc; + octave_idx_type nelem; +}; + +std::ostream& operator<< (std::ostream& os, const jit_range& rng); + +// jit_array is compatable with the llvm array/matrix structures +template <typename T, typename U> +struct +jit_array +{ + jit_array (T& from) : array (new T (from)) + { + update (); + } + + void update (void) + { + ref_count = array->jit_ref_count (); + slice_data = array->jit_slice_data () - 1; + slice_len = array->capacity (); + dimensions = array->jit_dimensions (); + } + + void update (T *aarray) + { + array = aarray; + update (); + } + + operator T () const + { + return *array; + } + + int *ref_count; + + U *slice_data; + octave_idx_type slice_len; + octave_idx_type *dimensions; + + T *array; +}; + +typedef jit_array<NDArray, double> jit_matrix; + +std::ostream& operator<< (std::ostream& os, const jit_matrix& mat); + +// calling convention +namespace +jit_convention +{ + enum + type + { + // internal to jit + internal, + + // an external C call + external, + + length + }; +} + +// Used to keep track of estimated (infered) types during JIT. This is a +// hierarchical type system which includes both concrete and abstract types. +// +// The types form a lattice. Currently we only allow for one parent type, but +// eventually we may allow for multiple predecessors. +class +jit_type +{ +public: + typedef llvm::Value *(*convert_fn) (llvm::IRBuilderD&, llvm::Value *); + + jit_type (const std::string& aname, jit_type *aparent, llvm::Type *allvm_type, + int aid); + + // a user readable type name + const std::string& name (void) const { return mname; } + + // a unique id for the type + int type_id (void) const { return mid; } + + // An abstract base type, may be null + jit_type *parent (void) const { return mparent; } + + // convert to an llvm type + llvm::Type *to_llvm (void) const { return llvm_type; } + + // how this type gets passed as a function argument + llvm::Type *to_llvm_arg (void) const; + + size_t depth (void) const { return mdepth; } + + // -------------------- Calling Convention information -------------------- + + // A function declared like: mytype foo (int arg0, int arg1); + // Will be converted to: void foo (mytype *retval, int arg0, int arg1) + // if mytype is sret. The caller is responsible for allocating space for + // retval. (on the stack) + bool sret (jit_convention::type cc) const { return msret[cc]; } + + void mark_sret (jit_convention::type cc = jit_convention::external) + { msret[cc] = true; } + + // A function like: void foo (mytype arg0) + // Will be converted to: void foo (mytype *arg0) + // Basically just pass by reference. + bool pointer_arg (jit_convention::type cc) const { return mpointer_arg[cc]; } + + void mark_pointer_arg (jit_convention::type cc = jit_convention::external) + { mpointer_arg[cc] = true; } + + // Convert into an equivalent form before calling. For example, complex is + // represented as two values llvm vector, but we need to pass it as a two + // valued llvm structure to C functions. + convert_fn pack (jit_convention::type cc) { return mpack[cc]; } + + void set_pack (jit_convention::type cc, convert_fn fn) { mpack[cc] = fn; } + + // The inverse operation of pack. + convert_fn unpack (jit_convention::type cc) { return munpack[cc]; } + + void set_unpack (jit_convention::type cc, convert_fn fn) + { munpack[cc] = fn; } + + // The resulting type after pack is called. + llvm::Type *packed_type (jit_convention::type cc) + { return mpacked_type[cc]; } + + void set_packed_type (jit_convention::type cc, llvm::Type *ty) + { mpacked_type[cc] = ty; } +private: + std::string mname; + jit_type *mparent; + llvm::Type *llvm_type; + int mid; + size_t mdepth; + + bool msret[jit_convention::length]; + bool mpointer_arg[jit_convention::length]; + + convert_fn mpack[jit_convention::length]; + convert_fn munpack[jit_convention::length]; + + llvm::Type *mpacked_type[jit_convention::length]; +}; + +// seperate print function to allow easy printing if type is null +std::ostream& jit_print (std::ostream& os, jit_type *atype); + +class jit_value; + +// An abstraction for calling llvm functions with jit_values. Deals with calling +// convention details. +class +jit_function +{ + friend std::ostream& operator<< (std::ostream& os, const jit_function& fn); +public: + // create a function in an invalid state + jit_function (); + + jit_function (llvm::Module *amodule, jit_convention::type acall_conv, + const llvm::Twine& aname, jit_type *aresult, + const std::vector<jit_type *>& aargs); + + // Use an existing function, but change the argument types. The new argument + // types must behave the same for the current calling convention. + jit_function (const jit_function& fn, jit_type *aresult, + const std::vector<jit_type *>& aargs); + + jit_function (const jit_function& fn); + + bool valid (void) const { return llvm_function; } + + std::string name (void) const; + + llvm::BasicBlock *new_block (const std::string& aname = "body", + llvm::BasicBlock *insert_before = 0); + + llvm::Value *call (llvm::IRBuilderD& builder, + const std::vector<jit_value *>& in_args) const; + + llvm::Value *call (llvm::IRBuilderD& builder, + const std::vector<llvm::Value *>& in_args + = std::vector<llvm::Value *> ()) const; + +#define JIT_PARAM_ARGS llvm::IRBuilderD& builder, +#define JIT_PARAMS builder, +#define JIT_CALL(N) JIT_EXPAND (llvm::Value *, call, llvm::Value *, const, N) + + JIT_CALL (1) + JIT_CALL (2) + JIT_CALL (3) + JIT_CALL (4) + JIT_CALL (5) + +#undef JIT_CALL + +#define JIT_CALL(N) JIT_EXPAND (llvm::Value *, call, jit_value *, const, N) + + JIT_CALL (1); + JIT_CALL (2); + +#undef JIT_CALL +#undef JIT_PARAMS +#undef JIT_PARAM_ARGS + + llvm::Value *argument (llvm::IRBuilderD& builder, size_t idx) const; + + void do_return (llvm::IRBuilderD& builder, llvm::Value *rval = 0); + + llvm::Function *to_llvm (void) const { return llvm_function; } + + // If true, then the return value is passed as a pointer in the first argument + bool sret (void) const { return mresult && mresult->sret (call_conv); } + + bool can_error (void) const { return mcan_error; } + + void mark_can_error (void) { mcan_error = true; } + + jit_type *result (void) const { return mresult; } + + jit_type *argument_type (size_t idx) const + { + assert (idx < args.size ()); + return args[idx]; + } + + const std::vector<jit_type *>& arguments (void) const { return args; } +private: + llvm::Module *module; + llvm::Function *llvm_function; + jit_type *mresult; + std::vector<jit_type *> args; + jit_convention::type call_conv; + bool mcan_error; +}; + +std::ostream& operator<< (std::ostream& os, const jit_function& fn); + + +// Keeps track of information about how to implement operations (+, -, *, ect) +// and their resulting types. +class +jit_operation +{ +public: + void add_overload (const jit_function& func) + { + add_overload (func, func.arguments ()); + } + + void add_overload (const jit_function& func, + const std::vector<jit_type*>& args); + + const jit_function& overload (const std::vector<jit_type *>& types) const; + + jit_type *result (const std::vector<jit_type *>& types) const + { + const jit_function& temp = overload (types); + return temp.result (); + } + +#define JIT_PARAMS +#define JIT_PARAM_ARGS +#define JIT_OVERLOAD(N) \ + JIT_EXPAND (const jit_function&, overload, jit_type *, const, N) \ + JIT_EXPAND (jit_type *, result, jit_type *, const, N) + + JIT_OVERLOAD (1); + JIT_OVERLOAD (2); + JIT_OVERLOAD (3); + +#undef JIT_PARAMS +#undef JIT_PARAM_ARGS + + const std::string& name (void) const { return mname; } + + void stash_name (const std::string& aname) { mname = aname; } +private: + Array<octave_idx_type> to_idx (const std::vector<jit_type*>& types) const; + + std::vector<Array<jit_function> > overloads; + + std::string mname; +}; + +// A singleton class which handles the construction of jit_types and +// jit_operations. +class +jit_typeinfo +{ +public: + static void initialize (llvm::Module *m, llvm::ExecutionEngine *e); + + static jit_type *join (jit_type *lhs, jit_type *rhs) + { + return instance->do_join (lhs, rhs); + } + + static jit_type *get_any (void) { return instance->any; } + + static jit_type *get_matrix (void) { return instance->matrix; } + + static jit_type *get_scalar (void) { return instance->scalar; } + + static llvm::Type *get_scalar_llvm (void) + { return instance->scalar->to_llvm (); } + + static jit_type *get_range (void) { return instance->range; } + + static jit_type *get_string (void) { return instance->string; } + + static jit_type *get_bool (void) { return instance->boolean; } + + static jit_type *get_index (void) { return instance->index; } + + static llvm::Type *get_index_llvm (void) + { return instance->index->to_llvm (); } + + static jit_type *get_complex (void) { return instance->complex; } + + // Get the jit_type of an octave_value + static jit_type *type_of (const octave_value& ov) + { + return instance->do_type_of (ov); + } + + static const jit_operation& binary_op (int op) + { + return instance->do_binary_op (op); + } + + static const jit_operation& grab (void) { return instance->grab_fn; } + + static const jit_function& get_grab (jit_type *type) + { + return instance->grab_fn.overload (type); + } + + static const jit_operation& release (void) + { + return instance->release_fn; + } + + static const jit_function& get_release (jit_type *type) + { + return instance->release_fn.overload (type); + } + + static const jit_operation& print_value (void) + { + return instance->print_fn; + } + + static const jit_operation& for_init (void) + { + return instance->for_init_fn; + } + + static const jit_operation& for_check (void) + { + return instance->for_check_fn; + } + + static const jit_operation& for_index (void) + { + return instance->for_index_fn; + } + + static const jit_operation& make_range (void) + { + return instance->make_range_fn; + } + + static const jit_operation& paren_subsref (void) + { + return instance->paren_subsref_fn; + } + + static const jit_operation& paren_subsasgn (void) + { + return instance->paren_subsasgn_fn; + } + + static const jit_operation& logically_true (void) + { + return instance->logically_true_fn; + } + + static const jit_operation& cast (jit_type *result) + { + return instance->do_cast (result); + } + + static const jit_function& cast (jit_type *to, jit_type *from) + { + return instance->do_cast (to, from); + } + + static llvm::Value *insert_error_check (llvm::IRBuilderD& bld) + { + return instance->do_insert_error_check (bld); + } +private: + jit_typeinfo (llvm::Module *m, llvm::ExecutionEngine *e); + + // FIXME: Do these methods really need to be in jit_typeinfo? + jit_type *do_join (jit_type *lhs, jit_type *rhs) + { + // empty case + if (! lhs) + return rhs; + + if (! rhs) + return lhs; + + // check for a shared parent + while (lhs != rhs) + { + if (lhs->depth () > rhs->depth ()) + lhs = lhs->parent (); + else if (lhs->depth () < rhs->depth ()) + rhs = rhs->parent (); + else + { + // we MUST have depth > 0 as any is the base type of everything + do + { + lhs = lhs->parent (); + rhs = rhs->parent (); + } + while (lhs != rhs); + } + } + + return lhs; + } + + jit_type *do_difference (jit_type *lhs, jit_type *) + { + // FIXME: Maybe we can do something smarter? + return lhs; + } + + jit_type *do_type_of (const octave_value &ov) const; + + const jit_operation& do_binary_op (int op) const + { + assert (static_cast<size_t>(op) < binary_ops.size ()); + return binary_ops[op]; + } + + const jit_operation& do_cast (jit_type *to) + { + static jit_operation null_function; + if (! to) + return null_function; + + size_t id = to->type_id (); + if (id >= casts.size ()) + return null_function; + return casts[id]; + } + + const jit_function& do_cast (jit_type *to, jit_type *from) + { + return do_cast (to).overload (from); + } + + jit_type *new_type (const std::string& name, jit_type *parent, + llvm::Type *llvm_type); + + + void add_print (jit_type *ty); + + void add_binary_op (jit_type *ty, int op, int llvm_op); + + void add_binary_icmp (jit_type *ty, int op, int llvm_op); + + void add_binary_fcmp (jit_type *ty, int op, int llvm_op); + + jit_function create_function (jit_convention::type cc, + const llvm::Twine& name, jit_type *ret, + const std::vector<jit_type *>& args + = std::vector<jit_type *> ()); + +#define JIT_PARAM_ARGS jit_convention::type cc, const llvm::Twine& name, \ + jit_type *ret, +#define JIT_PARAMS cc, name, ret, +#define CREATE_FUNCTION(N) JIT_EXPAND(jit_function, create_function, \ + jit_type *, /* empty */, N) + + CREATE_FUNCTION(1); + CREATE_FUNCTION(2); + CREATE_FUNCTION(3); + CREATE_FUNCTION(4); + +#undef JIT_PARAM_ARGS +#undef JIT_PARAMS +#undef CREATE_FUNCTION + + jit_function create_identity (jit_type *type); + + llvm::Value *do_insert_error_check (llvm::IRBuilderD& bld); + + void add_builtin (const std::string& name); + + void register_intrinsic (const std::string& name, size_t id, + jit_type *result, jit_type *arg0) + { + std::vector<jit_type *> args (1, arg0); + register_intrinsic (name, id, result, args); + } + + void register_intrinsic (const std::string& name, size_t id, jit_type *result, + const std::vector<jit_type *>& args); + + void register_generic (const std::string& name, jit_type *result, + jit_type *arg0) + { + std::vector<jit_type *> args (1, arg0); + register_generic (name, result, args); + } + + void register_generic (const std::string& name, jit_type *result, + const std::vector<jit_type *>& args); + + octave_builtin *find_builtin (const std::string& name); + + jit_function mirror_binary (const jit_function& fn); + + llvm::Function *wrap_complex (llvm::Function *wrap); + + static llvm::Value *pack_complex (llvm::IRBuilderD& bld, + llvm::Value *cplx); + + static llvm::Value *unpack_complex (llvm::IRBuilderD& bld, + llvm::Value *result); + + llvm::Value *complex_real (llvm::Value *cx); + + llvm::Value *complex_real (llvm::Value *cx, llvm::Value *real); + + llvm::Value *complex_imag (llvm::Value *cx); + + llvm::Value *complex_imag (llvm::Value *cx, llvm::Value *imag); + + llvm::Value *complex_new (llvm::Value *real, llvm::Value *imag); + + void create_int (size_t nbits); + + jit_type *intN (size_t nbits) const; + + static jit_typeinfo *instance; + + llvm::Module *module; + llvm::ExecutionEngine *engine; + int next_id; + + llvm::GlobalVariable *lerror_state; + + std::vector<jit_type*> id_to_type; + jit_type *any; + jit_type *matrix; + jit_type *scalar; + jit_type *range; + jit_type *string; + jit_type *boolean; + jit_type *index; + jit_type *complex; + jit_type *unknown_function; + std::map<size_t, jit_type *> ints; + std::map<std::string, jit_type *> builtins; + + llvm::StructType *complex_ret; + + std::vector<jit_operation> binary_ops; + jit_operation grab_fn; + jit_operation release_fn; + jit_operation print_fn; + jit_operation for_init_fn; + jit_operation for_check_fn; + jit_operation for_index_fn; + jit_operation logically_true_fn; + jit_operation make_range_fn; + jit_operation paren_subsref_fn; + jit_operation paren_subsasgn_fn; + + // type id -> cast function TO that type + std::vector<jit_operation> casts; + + // type id -> identity function + std::vector<jit_function> identities; + + llvm::IRBuilderD& builder; +}; + +#endif +#endif
new file mode 100644 --- /dev/null +++ b/src/jit-util.cc @@ -0,0 +1,44 @@ +/* + +Copyright (C) 2012 Max Brister <max@2bass.com> + +This file is part of Octave. + +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 3 of the License, or (at your +option) any later version. + +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, see +<http://www.gnu.org/licenses/>. + +*/ + +// defines required by llvm +#define __STDC_LIMIT_MACROS +#define __STDC_CONSTANT_MACROS + +#ifdef HAVE_CONFIG_H +#include <config.h> +#endif + +#ifdef HAVE_LLVM + +#include <llvm/Value.h> +#include <llvm/Support/raw_os_ostream.h> + +std::ostream& +operator<< (std::ostream& os, const llvm::Value& v) +{ + llvm::raw_os_ostream llvm_out (os); + v.print (llvm_out); + return os; +} + +#endif
new file mode 100644 --- /dev/null +++ b/src/jit-util.h @@ -0,0 +1,203 @@ +/* + +Copyright (C) 2012 Max Brister <max@2bass.com> + +This file is part of Octave. + +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 3 of the License, or (at your +option) any later version. + +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, see +<http://www.gnu.org/licenses/>. + +*/ + +// Some utility classes and functions used throughout jit + +#if !defined (octave_jit_util_h) +#define octave_jit_util_h 1 + +#ifdef HAVE_LLVM + +#include <stdexcept> + +// we don't want to include llvm headers here, as they require +// __STDC_LIMIT_MACROS and __STDC_CONSTANT_MACROS be defined in the entire +// compilation unit +namespace llvm +{ + class Value; + class Module; + class FunctionPassManager; + class PassManager; + class ExecutionEngine; + class Function; + class BasicBlock; + class LLVMContext; + class Type; + class StructType; + class Twine; + class GlobalVariable; + class TerminatorInst; + class PHINode; + + class ConstantFolder; + + template <bool preserveNames> + class IRBuilderDefaultInserter; + + template <bool preserveNames, typename T, typename Inserter> + class IRBuilder; + +typedef IRBuilder<true, ConstantFolder, IRBuilderDefaultInserter<true> > +IRBuilderD; +} + +class octave_base_value; +class octave_builtin; +class octave_value; +class tree; +class tree_expression; + +// thrown when we should give up on JIT and interpret +class jit_fail_exception : public std::runtime_error +{ +public: + jit_fail_exception (void) : std::runtime_error ("unknown"), mknown (false) {} + jit_fail_exception (const std::string& reason) : std::runtime_error (reason), + mknown (true) + {} + + bool known (void) const { return mknown; } +private: + bool mknown; +}; + +// llvm doesn't provide this, and it's really useful for debugging +std::ostream& operator<< (std::ostream& os, const llvm::Value& v); + +template <typename HOLDER_T, typename SUB_T> +class jit_internal_node; + +// jit_internal_list and jit_internal_node implement generic embedded doubly +// linked lists. List items extend from jit_internal_list, and can be placed +// in nodes of type jit_internal_node. We use CRTP twice. +template <typename LIST_T, typename NODE_T> +class +jit_internal_list +{ + friend class jit_internal_node<LIST_T, NODE_T>; +public: + jit_internal_list (void) : use_head (0), use_tail (0), muse_count (0) {} + + virtual ~jit_internal_list (void) + { + while (use_head) + use_head->stash_value (0); + } + + NODE_T *first_use (void) const { return use_head; } + + size_t use_count (void) const { return muse_count; } +private: + NODE_T *use_head; + NODE_T *use_tail; + size_t muse_count; +}; + +// a node for internal linked lists +template <typename LIST_T, typename NODE_T> +class +jit_internal_node +{ +public: + typedef jit_internal_list<LIST_T, NODE_T> jit_ilist; + + jit_internal_node (void) : mvalue (0), mnext (0), mprev (0) {} + + ~jit_internal_node (void) { remove (); } + + LIST_T *value (void) const { return mvalue; } + + void stash_value (LIST_T *avalue) + { + remove (); + + mvalue = avalue; + + if (mvalue) + { + jit_ilist *ilist = mvalue; + NODE_T *sthis = static_cast<NODE_T *> (this); + if (ilist->use_head) + { + ilist->use_tail->mnext = sthis; + mprev = ilist->use_tail; + } + else + ilist->use_head = sthis; + + ilist->use_tail = sthis; + ++ilist->muse_count; + } + } + + NODE_T *next (void) const { return mnext; } + + NODE_T *prev (void) const { return mprev; } +private: + void remove () + { + if (mvalue) + { + jit_ilist *ilist = mvalue; + if (mprev) + mprev->mnext = mnext; + else + // we are the use_head + ilist->use_head = mnext; + + if (mnext) + mnext->mprev = mprev; + else + // we are the use tail + ilist->use_tail = mprev; + + mnext = mprev = 0; + --ilist->muse_count; + mvalue = 0; + } + } + + LIST_T *mvalue; + NODE_T *mnext; + NODE_T *mprev; +}; + +// Use like: isa<jit_phi> (value) +// basically just a short cut type typing dyanmic_cast. +template <typename T, typename U> +bool isa (U *value) +{ + return dynamic_cast<T *> (value); +} + +#define JIT_ASSIGN_ARG(i) the_args[i] = arg ## i; +#define JIT_EXPAND(ret, fname, type, isconst, N) \ + ret fname (JIT_PARAM_ARGS OCT_MAKE_DECL_LIST (type, arg, N)) isconst \ + { \ + std::vector<type> the_args (N); \ + OCT_ITERATE_MACRO (JIT_ASSIGN_ARG, N); \ + return fname (JIT_PARAMS the_args); \ + } + +#endif +#endif
--- a/src/pt-jit.cc +++ b/src/pt-jit.cc @@ -31,2340 +31,31 @@ #include "pt-jit.h" -#include <typeinfo> - -#include <llvm/LLVMContext.h> -#include <llvm/Module.h> -#include <llvm/Function.h> -#include <llvm/BasicBlock.h> -#include <llvm/Intrinsics.h> -#include <llvm/Support/IRBuilder.h> -#include <llvm/ExecutionEngine/ExecutionEngine.h> -#include <llvm/ExecutionEngine/JIT.h> -#include <llvm/PassManager.h> -#include <llvm/Analysis/Verifier.h> #include <llvm/Analysis/CallGraph.h> #include <llvm/Analysis/Passes.h> -#include <llvm/Target/TargetData.h> -#include <llvm/Transforms/Scalar.h> -#include <llvm/Transforms/IPO.h> +#include <llvm/Analysis/Verifier.h> +#include <llvm/ExecutionEngine/ExecutionEngine.h> +#include <llvm/ExecutionEngine/JIT.h> +#include <llvm/Module.h> +#include <llvm/PassManager.h> +#include <llvm/Support/IRBuilder.h> +#include <llvm/Support/raw_os_ostream.h> #include <llvm/Support/TargetSelect.h> -#include <llvm/Support/raw_os_ostream.h> -#include <llvm/Support/FormattedStream.h> -#include <llvm/Bitcode/ReaderWriter.h> +#include <llvm/Target/TargetData.h> +#include <llvm/Transforms/IPO.h> +#include <llvm/Transforms/Scalar.h> -#include "octave.h" -#include "ov-fcn-handle.h" -#include "ov-usr-fcn.h" -#include "ov-builtin.h" -#include "ov-scalar.h" -#include "ov-complex.h" +#ifdef OCTAVE_JIT_DEBUG +#include <llvm/Bitcode/ReaderWriter.h> +#endif + +#include "symtab.h" #include "pt-all.h" -#include "symtab.h" static llvm::IRBuilder<> builder (llvm::getGlobalContext ()); static llvm::LLVMContext& context = llvm::getGlobalContext (); -jit_typeinfo *jit_typeinfo::instance; - -// thrown when we should give up on JIT and interpret -class jit_fail_exception : public std::runtime_error -{ -public: - jit_fail_exception (void) : std::runtime_error ("unknown"), mknown (false) {} - jit_fail_exception (const std::string& reason) : std::runtime_error (reason), - mknown (true) - {} - - bool known (void) const { return mknown; } -private: - bool mknown; -}; - -static void fail (void) GCC_ATTR_NORETURN; -static void fail (const std::string&) GCC_ATTR_NORETURN; - -static void -fail (void) -{ - throw jit_fail_exception (); -} - -#ifdef OCTAVE_JIT_DEBUG -static void -fail (const std::string& reason) -{ - throw jit_fail_exception (reason); -} -#else -static void -fail (const std::string&) -{ - throw jit_fail_exception (); -} -#endif // OCTAVE_JIT_DEBUG - -std::ostream& jit_print (std::ostream& os, jit_type *atype) -{ - if (! atype) - return os << "null"; - return os << atype->name (); -} - -// function that jit code calls -extern "C" void -octave_jit_print_any (const char *name, octave_base_value *obv) -{ - obv->print_with_name (octave_stdout, name, true); -} - -extern "C" void -octave_jit_print_double (const char *name, double value) -{ - // FIXME: We should avoid allocating a new octave_scalar each time - octave_value ov (value); - ov.print_with_name (octave_stdout, name); -} - -extern "C" octave_base_value* -octave_jit_binary_any_any (octave_value::binary_op op, octave_base_value *lhs, - octave_base_value *rhs) -{ - octave_value olhs (lhs, true); - octave_value orhs (rhs, true); - octave_value result = do_binary_op (op, olhs, orhs); - octave_base_value *rep = result.internal_rep (); - rep->grab (); - return rep; -} - -extern "C" octave_idx_type -octave_jit_compute_nelem (double base, double limit, double inc) -{ - Range rng = Range (base, limit, inc); - return rng.nelem (); -} - -extern "C" void -octave_jit_release_any (octave_base_value *obv) -{ - obv->release (); -} - -extern "C" void -octave_jit_release_matrix (jit_matrix *m) -{ - delete m->array; -} - -extern "C" octave_base_value * -octave_jit_grab_any (octave_base_value *obv) -{ - obv->grab (); - return obv; -} - -extern "C" void -octave_jit_grab_matrix (jit_matrix *result, jit_matrix *m) -{ - *result = *m->array; -} - -extern "C" octave_base_value * -octave_jit_cast_any_matrix (jit_matrix *m) -{ - octave_value ret (*m->array); - octave_base_value *rep = ret.internal_rep (); - rep->grab (); - delete m->array; - - return rep; -} - -extern "C" void -octave_jit_cast_matrix_any (jit_matrix *ret, octave_base_value *obv) -{ - NDArray m = obv->array_value (); - *ret = m; - obv->release (); -} - -extern "C" double -octave_jit_cast_scalar_any (octave_base_value *obv) -{ - double ret = obv->double_value (); - obv->release (); - return ret; -} - -extern "C" octave_base_value * -octave_jit_cast_any_scalar (double value) -{ - return new octave_scalar (value); -} - -extern "C" Complex -octave_jit_cast_complex_any (octave_base_value *obv) -{ - Complex ret = obv->complex_value (); - obv->release (); - return ret; -} - -extern "C" octave_base_value * -octave_jit_cast_any_complex (Complex c) -{ - if (c.imag () == 0) - return new octave_scalar (c.real ()); - else - return new octave_complex (c); -} - -extern "C" void -octave_jit_gripe_nan_to_logical_conversion (void) -{ - try - { - gripe_nan_to_logical_conversion (); - } - catch (const octave_execution_exception&) - { - gripe_library_execution_error (); - } -} - -extern "C" void -octave_jit_ginvalid_index (void) -{ - try - { - gripe_invalid_index (); - } - catch (const octave_execution_exception&) - { - gripe_library_execution_error (); - } -} - -extern "C" void -octave_jit_gindex_range (int nd, int dim, octave_idx_type iext, - octave_idx_type ext) -{ - try - { - gripe_index_out_of_range (nd, dim, iext, ext); - } - catch (const octave_execution_exception&) - { - gripe_library_execution_error (); - } -} - -extern "C" void -octave_jit_paren_subsasgn_impl (jit_matrix *mat, octave_idx_type index, - double value) -{ - NDArray *array = mat->array; - if (array->nelem () < index) - array->resize1 (index); - - double *data = array->fortran_vec (); - data[index - 1] = value; - - mat->update (); -} - -extern "C" void -octave_jit_paren_subsasgn_matrix_range (jit_matrix *result, jit_matrix *mat, - jit_range *index, double value) -{ - NDArray *array = mat->array; - bool done = false; - - // optimize for the simple case (no resizing and no errors) - if (*array->jit_ref_count () == 1 - && index->all_elements_are_ints ()) - { - // this code is similar to idx_vector::fill, but we avoid allocating an - // idx_vector and its associated rep - octave_idx_type start = static_cast<octave_idx_type> (index->base) - 1; - octave_idx_type step = static_cast<octave_idx_type> (index->inc); - octave_idx_type nelem = index->nelem; - octave_idx_type final = start + nelem * step; - if (step < 0) - { - step = -step; - std::swap (final, start); - } - - if (start >= 0 && final < mat->slice_len) - { - done = true; - - double *data = array->jit_slice_data (); - if (step == 1) - std::fill (data + start, data + start + nelem, value); - else - { - for (octave_idx_type i = start; i < final; i += step) - data[i] = value; - } - } - } - - if (! done) - { - idx_vector idx (*index); - NDArray avalue (dim_vector (1, 1)); - avalue.xelem (0) = value; - array->assign (idx, avalue); - } - - result->update (array); -} - -extern "C" Complex -octave_jit_complex_div (Complex lhs, Complex rhs) -{ - // see src/OPERATORS/op-cs-cs.cc - if (rhs == 0.0) - gripe_divide_by_zero (); - - return lhs / rhs; -} - -// FIXME: CP form src/xpow.cc -static inline int -xisint (double x) -{ - return (D_NINT (x) == x - && ((x >= 0 && x < INT_MAX) - || (x <= 0 && x > INT_MIN))); -} - -extern "C" Complex -octave_jit_pow_scalar_scalar (double lhs, double rhs) -{ - // FIXME: almost CP from src/xpow.cc - if (lhs < 0.0 && ! xisint (rhs)) - return std::pow (Complex (lhs), rhs); - return std::pow (lhs, rhs); -} - -extern "C" Complex -octave_jit_pow_complex_complex (Complex lhs, Complex rhs) -{ - if (lhs.imag () == 0 && rhs.imag () == 0) - return octave_jit_pow_scalar_scalar (lhs.real (), rhs.real ()); - return std::pow (lhs, rhs); -} - -extern "C" Complex -octave_jit_pow_complex_scalar (Complex lhs, double rhs) -{ - if (lhs.imag () == 0) - return octave_jit_pow_scalar_scalar (lhs.real (), rhs); - return std::pow (lhs, rhs); -} - -extern "C" Complex -octave_jit_pow_scalar_complex (double lhs, Complex rhs) -{ - if (rhs.imag () == 0) - return octave_jit_pow_scalar_scalar (lhs, rhs.real ()); - return std::pow (lhs, rhs); -} - -extern "C" void -octave_jit_print_matrix (jit_matrix *m) -{ - std::cout << *m << std::endl; -} - -static void -gripe_bad_result (void) -{ - error ("incorrect type information given to the JIT compiler"); -} - -// FIXME: Add support for multiple outputs -extern "C" octave_base_value * -octave_jit_call (octave_builtin::fcn fn, size_t nargin, - octave_base_value **argin, jit_type *result_type) -{ - octave_value_list ovl (nargin); - for (size_t i = 0; i < nargin; ++i) - ovl.xelem (i) = octave_value (argin[i]); - - ovl = fn (ovl, 1); - - // These type checks are not strictly required, but I'm guessing that - // incorrect types will be entered on occasion. This will be very difficult to - // debug unless we do the sanity check here. - if (result_type) - { - if (ovl.length () != 1) - { - gripe_bad_result (); - return 0; - } - - octave_value& result = ovl.xelem (0); - jit_type *jtype = jit_typeinfo::join (jit_typeinfo::type_of (result), - result_type); - if (jtype != result_type) - { - gripe_bad_result (); - return 0; - } - - octave_base_value *ret = result.internal_rep (); - ret->grab (); - return ret; - } - - if (! (ovl.length () == 0 - || (ovl.length () == 1 && ovl.xelem (0).is_undefined ()))) - gripe_bad_result (); - - return 0; -} - -std::ostream& -operator<< (std::ostream& os, const llvm::Value& v) -{ - llvm::raw_os_ostream llvm_out (os); - v.print (llvm_out); - return os; -} - -// -------------------- jit_range -------------------- -bool -jit_range::all_elements_are_ints () const -{ - Range r (*this); - return r.all_elements_are_ints (); -} - -std::ostream& -operator<< (std::ostream& os, const jit_range& rng) -{ - return os << "Range[" << rng.base << ", " << rng.limit << ", " << rng.inc - << ", " << rng.nelem << "]"; -} - -// -------------------- jit_matrix -------------------- - -std::ostream& -operator<< (std::ostream& os, const jit_matrix& mat) -{ - return os << "Matrix[" << mat.ref_count << ", " << mat.slice_data << ", " - << mat.slice_len << ", " << mat.dimensions << ", " - << mat.array << "]"; -} - -// -------------------- jit_type -------------------- -jit_type::jit_type (const std::string& aname, jit_type *aparent, - llvm::Type *allvm_type, int aid) : - mname (aname), mparent (aparent), llvm_type (allvm_type), mid (aid), - mdepth (aparent ? aparent->mdepth + 1 : 0) -{ - std::memset (msret, 0, sizeof (msret)); - std::memset (mpointer_arg, 0, sizeof (mpointer_arg)); - std::memset (mpack, 0, sizeof (mpack)); - std::memset (munpack, 0, sizeof (munpack)); - - for (size_t i = 0; i < jit_convention::length; ++i) - mpacked_type[i] = llvm_type; -} - -llvm::Type * -jit_type::to_llvm_arg (void) const -{ - return llvm_type ? llvm_type->getPointerTo () : 0; -} - -// -------------------- jit_function -------------------- -jit_function::jit_function () : module (0), llvm_function (0), mresult (0), - call_conv (jit_convention::length), - mcan_error (false) -{} - -jit_function::jit_function (llvm::Module *amodule, - jit_convention::type acall_conv, - const llvm::Twine& aname, jit_type *aresult, - const std::vector<jit_type *>& aargs) - : module (amodule), mresult (aresult), args (aargs), call_conv (acall_conv), - mcan_error (false) -{ - llvm::SmallVector<llvm::Type *, 15> llvm_args; - - llvm::Type *rtype = builder.getVoidTy (); - if (mresult) - { - rtype = mresult->packed_type (call_conv); - if (sret ()) - { - llvm_args.push_back (rtype->getPointerTo ()); - rtype = builder.getVoidTy (); - } - } - - for (std::vector<jit_type *>::const_iterator iter = args.begin (); - iter != args.end (); ++iter) - { - jit_type *ty = *iter; - assert (ty); - llvm::Type *argty = ty->packed_type (call_conv); - if (ty->pointer_arg (call_conv)) - argty = argty->getPointerTo (); - - llvm_args.push_back (argty); - } - - // we mark all functinos as external linkage because this prevents llvm - // from getting rid of always inline functions - llvm::FunctionType *ft = llvm::FunctionType::get (rtype, llvm_args, false); - llvm_function = llvm::Function::Create (ft, llvm::Function::ExternalLinkage, - aname, module); - if (call_conv == jit_convention::internal) - llvm_function->addFnAttr (llvm::Attribute::AlwaysInline); -} - -jit_function::jit_function (const jit_function& fn, jit_type *aresult, - const std::vector<jit_type *>& aargs) - : module (fn.module), llvm_function (fn.llvm_function), mresult (aresult), - args (aargs), call_conv (fn.call_conv), mcan_error (fn.mcan_error) -{ -} - -jit_function::jit_function (const jit_function& fn) - : module (fn.module), llvm_function (fn.llvm_function), mresult (fn.mresult), - args (fn.args), call_conv (fn.call_conv), mcan_error (fn.mcan_error) -{} - -std::string -jit_function::name (void) const -{ - return llvm_function->getName (); -} - -llvm::BasicBlock * -jit_function::new_block (const std::string& aname, - llvm::BasicBlock *insert_before) -{ - return llvm::BasicBlock::Create (context, aname, llvm_function, - insert_before); -} - -llvm::Value * -jit_function::call (const std::vector<jit_value *>& in_args) const -{ - assert (in_args.size () == args.size ()); - - std::vector<llvm::Value *> llvm_args (args.size ()); - for (size_t i = 0; i < in_args.size (); ++i) - llvm_args[i] = in_args[i]->to_llvm (); - - return call (llvm_args); -} - -llvm::Value * -jit_function::call (const std::vector<llvm::Value *>& in_args) const -{ - assert (valid ()); - assert (in_args.size () == args.size ()); - llvm::Function *stacksave - = llvm::Intrinsic::getDeclaration (module, llvm::Intrinsic::stacksave); - llvm::SmallVector<llvm::Value *, 10> llvm_args; - llvm_args.reserve (in_args.size () + sret ()); - - llvm::Value *sret_mem = 0; - llvm::Value *saved_stack = 0; - if (sret ()) - { - saved_stack = builder.CreateCall (stacksave); - sret_mem = builder.CreateAlloca (mresult->packed_type (call_conv)); - llvm_args.push_back (sret_mem); - } - - for (size_t i = 0; i < in_args.size (); ++i) - { - llvm::Value *arg = in_args[i]; - jit_type::convert_fn convert = args[i]->pack (call_conv); - if (convert) - arg = convert (arg); - - if (args[i]->pointer_arg (call_conv)) - { - if (! saved_stack) - saved_stack = builder.CreateCall (stacksave); - - arg = builder.CreateAlloca (args[i]->to_llvm ()); - builder.CreateStore (in_args[i], arg); - } - - llvm_args.push_back (arg); - } - - llvm::Value *ret = builder.CreateCall (llvm_function, llvm_args); - if (sret_mem) - ret = builder.CreateLoad (sret_mem); - - if (mresult) - { - jit_type::convert_fn unpack = mresult->unpack (call_conv); - if (unpack) - ret = unpack (ret); - } - - if (saved_stack) - { - llvm::Function *stackrestore - = llvm::Intrinsic::getDeclaration (module, - llvm::Intrinsic::stackrestore); - builder.CreateCall (stackrestore, saved_stack); - } - - return ret; -} - -llvm::Value * -jit_function::argument (size_t idx) const -{ - assert (idx < args.size ()); - - // FIXME: We should be treating arguments like a list, not a vector. Shouldn't - // matter much for now, as the number of arguments shouldn't be much bigger - // than 4 - llvm::Function::arg_iterator iter = llvm_function->arg_begin (); - if (sret ()) - ++iter; - - for (size_t i = 0; i < idx; ++i, ++iter); - - if (args[idx]->pointer_arg (call_conv)) - return builder.CreateLoad (iter); - - return iter; -} - -void -jit_function::do_return (llvm::Value *rval) -{ - assert (! rval == ! mresult); - - if (rval) - { - jit_type::convert_fn convert = mresult->pack (call_conv); - if (convert) - rval = convert (rval); - - if (sret ()) - builder.CreateStore (rval, llvm_function->arg_begin ()); - else - builder.CreateRet (rval); - } - else - builder.CreateRetVoid (); - - llvm::verifyFunction (*llvm_function); -} - -std::ostream& -operator<< (std::ostream& os, const jit_function& fn) -{ - llvm::Function *lfn = fn.to_llvm (); - os << "jit_function: cc=" << fn.call_conv; - llvm::raw_os_ostream llvm_out (os); - lfn->print (llvm_out); - llvm_out.flush (); - return os; -} - -// -------------------- jit_operation -------------------- -void -jit_operation::add_overload (const jit_function& func, - const std::vector<jit_type*>& args) -{ - if (args.size () >= overloads.size ()) - overloads.resize (args.size () + 1); - - Array<jit_function>& over = overloads[args.size ()]; - dim_vector dv (over.dims ()); - Array<octave_idx_type> idx = to_idx (args); - bool must_resize = false; - - if (dv.length () != idx.numel ()) - { - dv.resize (idx.numel ()); - must_resize = true; - } - - for (octave_idx_type i = 0; i < dv.length (); ++i) - if (dv(i) <= idx(i)) - { - must_resize = true; - dv(i) = idx(i) + 1; - } - - if (must_resize) - over.resize (dv); - - over(idx) = func; -} - -const jit_function& -jit_operation::overload (const std::vector<jit_type*>& types) const -{ - // FIXME: We should search for the next best overload on failure - static jit_function null_overload; - if (types.size () >= overloads.size ()) - return null_overload; - - for (size_t i =0; i < types.size (); ++i) - if (! types[i]) - return null_overload; - - const Array<jit_function>& over = overloads[types.size ()]; - dim_vector dv (over.dims ()); - Array<octave_idx_type> idx = to_idx (types); - for (octave_idx_type i = 0; i < dv.length (); ++i) - if (idx(i) >= dv(i)) - return null_overload; - - return over(idx); -} - -Array<octave_idx_type> -jit_operation::to_idx (const std::vector<jit_type*>& types) const -{ - octave_idx_type numel = types.size (); - if (numel == 1) - numel = 2; - - Array<octave_idx_type> idx (dim_vector (1, numel)); - for (octave_idx_type i = 0; i < static_cast<octave_idx_type> (types.size ()); - ++i) - idx(i) = types[i]->type_id (); - - if (types.size () == 1) - { - idx(1) = idx(0); - idx(0) = 0; - } - - return idx; -} - -// -------------------- jit_typeinfo -------------------- -void -jit_typeinfo::initialize (llvm::Module *m, llvm::ExecutionEngine *e) -{ - new jit_typeinfo (m, e); -} - -jit_typeinfo::jit_typeinfo (llvm::Module *m, llvm::ExecutionEngine *e) - : module (m), engine (e), next_id (0) -{ - instance = this; - - // FIXME: We should be registering types like in octave_value_typeinfo - llvm::Type *any_t = llvm::StructType::create (context, "octave_base_value"); - any_t = any_t->getPointerTo (); - - llvm::Type *scalar_t = llvm::Type::getDoubleTy (context); - llvm::Type *bool_t = llvm::Type::getInt1Ty (context); - llvm::Type *string_t = llvm::Type::getInt8Ty (context); - string_t = string_t->getPointerTo (); - llvm::Type *index_t = llvm::Type::getIntNTy (context, - sizeof(octave_idx_type) * 8); - - llvm::StructType *range_t = llvm::StructType::create (context, "range"); - std::vector<llvm::Type *> range_contents (4, scalar_t); - range_contents[3] = index_t; - range_t->setBody (range_contents); - - llvm::Type *refcount_t = llvm::Type::getIntNTy (context, sizeof(int) * 8); - - llvm::StructType *matrix_t = llvm::StructType::create (context, "matrix"); - llvm::Type *matrix_contents[5]; - matrix_contents[0] = refcount_t->getPointerTo (); - matrix_contents[1] = scalar_t->getPointerTo (); - matrix_contents[2] = index_t; - matrix_contents[3] = index_t->getPointerTo (); - matrix_contents[4] = string_t; - matrix_t->setBody (llvm::makeArrayRef (matrix_contents, 5)); - - llvm::Type *complex_t = llvm::VectorType::get (scalar_t, 2); - - // this is the structure that C functions return. Use this in order to get calling - // conventions right. - complex_ret = llvm::StructType::create (context, "complex_ret"); - llvm::Type *complex_ret_contents[] = {scalar_t, scalar_t}; - complex_ret->setBody (complex_ret_contents); - - // create types - any = new_type ("any", 0, any_t); - matrix = new_type ("matrix", any, matrix_t); - complex = new_type ("complex", any, complex_t); - scalar = new_type ("scalar", complex, scalar_t); - range = new_type ("range", any, range_t); - string = new_type ("string", any, string_t); - boolean = new_type ("bool", any, bool_t); - index = new_type ("index", any, index_t); - - create_int (8); - create_int (16); - create_int (32); - create_int (64); - - casts.resize (next_id + 1); - identities.resize (next_id + 1); - - // specify calling conventions - // FIXME: We should detect architecture and do something sane based on that - // here we assume x86 or x86_64 - matrix->mark_sret (); - matrix->mark_pointer_arg (); - - range->mark_sret (); - range->mark_pointer_arg (); - - complex->set_pack (jit_convention::external, &jit_typeinfo::pack_complex); - complex->set_unpack (jit_convention::external, &jit_typeinfo::unpack_complex); - complex->set_packed_type (jit_convention::external, complex_ret); - - if (sizeof (void *) == 4) - complex->mark_sret (); - - // bind global variables - lerror_state = new llvm::GlobalVariable (*module, bool_t, false, - llvm::GlobalValue::ExternalLinkage, - 0, "error_state"); - engine->addGlobalMapping (lerror_state, - reinterpret_cast<void *> (&error_state)); - - // any with anything is an any op - jit_function fn; - jit_type *binary_op_type = intN (sizeof (octave_value::binary_op) * 8); - llvm::Type *llvm_bo_type = binary_op_type->to_llvm (); - jit_function any_binary = create_function (jit_convention::external, - "octave_jit_binary_any_any", - any, binary_op_type, any, any); - any_binary.mark_can_error (); - binary_ops.resize (octave_value::num_binary_ops); - for (size_t i = 0; i < octave_value::num_binary_ops; ++i) - { - octave_value::binary_op op = static_cast<octave_value::binary_op> (i); - std::string op_name = octave_value::binary_op_as_string (op); - binary_ops[i].stash_name ("binary" + op_name); - } - - for (int op = 0; op < octave_value::num_binary_ops; ++op) - { - llvm::Twine fn_name ("octave_jit_binary_any_any_"); - fn_name = fn_name + llvm::Twine (op); - - fn = create_function (jit_convention::internal, fn_name, any, any, any); - fn.mark_can_error (); - llvm::BasicBlock *block = fn.new_block (); - builder.SetInsertPoint (block); - llvm::APInt op_int(sizeof (octave_value::binary_op) * 8, op, - std::numeric_limits<octave_value::binary_op>::is_signed); - llvm::Value *op_as_llvm = llvm::ConstantInt::get (llvm_bo_type, op_int); - llvm::Value *ret = any_binary.call (op_as_llvm, fn.argument (0), - fn.argument (1)); - fn.do_return (ret); - binary_ops[op].add_overload (fn); - } - - // grab any - fn = create_function (jit_convention::external, "octave_jit_grab_any", any, - any); - grab_fn.add_overload (fn); - grab_fn.stash_name ("grab"); - - // grab matrix - fn = create_function (jit_convention::external, "octave_jit_grab_matrix", - matrix, matrix); - grab_fn.add_overload (fn); - - // release any - fn = create_function (jit_convention::external, "octave_jit_release_any", 0, - any); - release_fn.add_overload (fn); - release_fn.stash_name ("release"); - - // release matrix - fn = create_function (jit_convention::external, "octave_jit_release_matrix", - 0, matrix); - release_fn.add_overload (fn); - - // release scalar - fn = create_identity (scalar); - release_fn.add_overload (fn); - - // release complex - fn = create_identity (complex); - release_fn.add_overload (fn); - - // release index - fn = create_identity (index); - release_fn.add_overload (fn); - - // now for binary scalar operations - // FIXME: Finish all operations - add_binary_op (scalar, octave_value::op_add, llvm::Instruction::FAdd); - add_binary_op (scalar, octave_value::op_sub, llvm::Instruction::FSub); - add_binary_op (scalar, octave_value::op_mul, llvm::Instruction::FMul); - add_binary_op (scalar, octave_value::op_el_mul, llvm::Instruction::FMul); - - add_binary_fcmp (scalar, octave_value::op_lt, llvm::CmpInst::FCMP_ULT); - add_binary_fcmp (scalar, octave_value::op_le, llvm::CmpInst::FCMP_ULE); - add_binary_fcmp (scalar, octave_value::op_eq, llvm::CmpInst::FCMP_UEQ); - add_binary_fcmp (scalar, octave_value::op_ge, llvm::CmpInst::FCMP_UGE); - add_binary_fcmp (scalar, octave_value::op_gt, llvm::CmpInst::FCMP_UGT); - add_binary_fcmp (scalar, octave_value::op_ne, llvm::CmpInst::FCMP_UNE); - - jit_function gripe_div0 = create_function (jit_convention::external, - "gripe_divide_by_zero", 0); - gripe_div0.mark_can_error (); - - // divide is annoying because it might error - fn = create_function (jit_convention::internal, - "octave_jit_div_scalar_scalar", scalar, scalar, scalar); - fn.mark_can_error (); - - llvm::BasicBlock *body = fn.new_block (); - builder.SetInsertPoint (body); - { - llvm::BasicBlock *warn_block = fn.new_block ("warn"); - llvm::BasicBlock *normal_block = fn.new_block ("normal"); - - llvm::Value *zero = llvm::ConstantFP::get (scalar_t, 0); - llvm::Value *check = builder.CreateFCmpUEQ (zero, fn.argument (0)); - builder.CreateCondBr (check, warn_block, normal_block); - - builder.SetInsertPoint (warn_block); - gripe_div0.call (); - builder.CreateBr (normal_block); - - builder.SetInsertPoint (normal_block); - llvm::Value *ret = builder.CreateFDiv (fn.argument (0), - fn.argument (1)); - fn.do_return (ret); - } - binary_ops[octave_value::op_div].add_overload (fn); - binary_ops[octave_value::op_el_div].add_overload (fn); - - // ldiv is the same as div with the operators reversed - fn = mirror_binary (fn); - binary_ops[octave_value::op_ldiv].add_overload (fn); - binary_ops[octave_value::op_el_ldiv].add_overload (fn); - - // In general, the result of scalar ^ scalar is a complex number. We might be - // able to improve on this if we keep track of the range of values varaibles - // can take on. - fn = create_function (jit_convention::external, - "octave_jit_pow_scalar_scalar", complex, scalar, - scalar); - binary_ops[octave_value::op_pow].add_overload (fn); - binary_ops[octave_value::op_el_pow].add_overload (fn); - - // now for binary complex operations - add_binary_op (complex, octave_value::op_add, llvm::Instruction::FAdd); - add_binary_op (complex, octave_value::op_sub, llvm::Instruction::FSub); - - fn = create_function (jit_convention::internal, - "octave_jit_*_complex_complex", complex, complex, - complex); - body = fn.new_block (); - builder.SetInsertPoint (body); - { - // (x0*x1 - y0*y1, x0*y1 + y0*x1) = (x0,y0) * (x1,y1) - // We compute this in one vectorized multiplication, a subtraction, and an - // addition. - llvm::Value *lhs = fn.argument (0); - llvm::Value *rhs = fn.argument (1); - - // FIXME: We need a better way of doing this, working with llvm's IR - // directly is sort of a pain. - llvm::Value *zero = builder.getInt32 (0); - llvm::Value *one = builder.getInt32 (1); - llvm::Value *two = builder.getInt32 (2); - llvm::Value *three = builder.getInt32 (3); - - llvm::Type *vec4 = llvm::VectorType::get (scalar_t, 4); - llvm::Value *mlhs = llvm::UndefValue::get (vec4); - llvm::Value *mrhs = mlhs; - - llvm::Value *temp = complex_real (lhs); - mlhs = builder.CreateInsertElement (mlhs, temp, zero); - mlhs = builder.CreateInsertElement (mlhs, temp, two); - temp = complex_imag (lhs); - mlhs = builder.CreateInsertElement (mlhs, temp, one); - mlhs = builder.CreateInsertElement (mlhs, temp, three); - - temp = complex_real (rhs); - mrhs = builder.CreateInsertElement (mrhs, temp, zero); - mrhs = builder.CreateInsertElement (mrhs, temp, three); - temp = complex_imag (rhs); - mrhs = builder.CreateInsertElement (mrhs, temp, one); - mrhs = builder.CreateInsertElement (mrhs, temp, two); - - llvm::Value *mres = builder.CreateFMul (mlhs, mrhs); - llvm::Value *tlhs = builder.CreateExtractElement (mres, zero); - llvm::Value *trhs = builder.CreateExtractElement (mres, one); - llvm::Value *ret_real = builder.CreateFSub (tlhs, trhs); - - tlhs = builder.CreateExtractElement (mres, two); - trhs = builder.CreateExtractElement (mres, three); - llvm::Value *ret_imag = builder.CreateFAdd (tlhs, trhs); - fn.do_return (complex_new (ret_real, ret_imag)); - } - - binary_ops[octave_value::op_mul].add_overload (fn); - binary_ops[octave_value::op_el_mul].add_overload (fn); - - jit_function complex_div = create_function (jit_convention::external, - "octave_jit_complex_div", - complex, complex, complex); - complex_div.mark_can_error (); - binary_ops[octave_value::op_div].add_overload (fn); - binary_ops[octave_value::op_ldiv].add_overload (fn); - - fn = mirror_binary (complex_div); - binary_ops[octave_value::op_ldiv].add_overload (fn); - binary_ops[octave_value::op_el_ldiv].add_overload (fn); - - fn = create_function (jit_convention::external, - "octave_jit_pow_complex_complex", complex, complex, - complex); - binary_ops[octave_value::op_pow].add_overload (fn); - binary_ops[octave_value::op_el_pow].add_overload (fn); - - fn = create_function (jit_convention::internal, - "octave_jit_*_scalar_complex", complex, scalar, - complex); - jit_function mul_scalar_complex = fn; - body = fn.new_block (); - builder.SetInsertPoint (body); - { - llvm::Value *lhs = fn.argument (0); - llvm::Value *tlhs = complex_new (lhs, lhs); - llvm::Value *rhs = fn.argument (1); - fn.do_return (builder.CreateFMul (tlhs, rhs)); - } - binary_ops[octave_value::op_mul].add_overload (fn); - binary_ops[octave_value::op_el_mul].add_overload (fn); - - - fn = mirror_binary (mul_scalar_complex); - binary_ops[octave_value::op_mul].add_overload (fn); - binary_ops[octave_value::op_el_mul].add_overload (fn); - - fn = create_function (jit_convention::internal, "octave_jit_+_scalar_complex", - complex, scalar, complex); - body = fn.new_block (); - builder.SetInsertPoint (body); - { - llvm::Value *lhs = fn.argument (0); - llvm::Value *rhs = fn.argument (1); - llvm::Value *real = builder.CreateFAdd (lhs, complex_real (rhs)); - fn.do_return (complex_real (rhs, real)); - } - binary_ops[octave_value::op_add].add_overload (fn); - - fn = mirror_binary (fn); - binary_ops[octave_value::op_add].add_overload (fn); - - fn = create_function (jit_convention::internal, "octave_jit_-_complex_scalar", - complex, complex, scalar); - body = fn.new_block (); - builder.SetInsertPoint (body); - { - llvm::Value *lhs = fn.argument (0); - llvm::Value *rhs = fn.argument (1); - llvm::Value *real = builder.CreateFSub (complex_real (lhs), rhs); - fn.do_return (complex_real (lhs, real)); - } - binary_ops[octave_value::op_sub].add_overload (fn); - - fn = create_function (jit_convention::internal, "octave_jit_-_scalar_complex", - complex, scalar, complex); - body = fn.new_block (); - builder.SetInsertPoint (body); - { - llvm::Value *lhs = fn.argument (0); - llvm::Value *rhs = fn.argument (1); - llvm::Value *real = builder.CreateFSub (lhs, complex_real (rhs)); - fn.do_return (complex_real (rhs, real)); - } - binary_ops[octave_value::op_sub].add_overload (fn); - - fn = create_function (jit_convention::external, - "octave_jit_pow_scalar_complex", complex, scalar, - complex); - binary_ops[octave_value::op_pow].add_overload (fn); - binary_ops[octave_value::op_el_pow].add_overload (fn); - - fn = create_function (jit_convention::external, - "octave_jit_pow_complex_scalar", complex, complex, - scalar); - binary_ops[octave_value::op_pow].add_overload (fn); - binary_ops[octave_value::op_el_pow].add_overload (fn); - - // now for binary index operators - add_binary_op (index, octave_value::op_add, llvm::Instruction::Add); - - // and binary bool operators - add_binary_op (boolean, octave_value::op_el_or, llvm::Instruction::Or); - add_binary_op (boolean, octave_value::op_el_and, llvm::Instruction::And); - - // now for printing functions - print_fn.stash_name ("print"); - add_print (any); - add_print (scalar); - - // initialize for loop - for_init_fn.stash_name ("for_init"); - - fn = create_function (jit_convention::internal, "octave_jit_for_range_init", - index, range); - body = fn.new_block (); - builder.SetInsertPoint (body); - { - llvm::Value *zero = llvm::ConstantInt::get (index_t, 0); - fn.do_return (zero); - } - for_init_fn.add_overload (fn); - - // bounds check for for loop - for_check_fn.stash_name ("for_check"); - - fn = create_function (jit_convention::internal, "octave_jit_for_range_check", - boolean, range, index); - body = fn.new_block (); - builder.SetInsertPoint (body); - { - llvm::Value *nelem - = builder.CreateExtractValue (fn.argument (0), 3); - llvm::Value *idx = fn.argument (1); - llvm::Value *ret = builder.CreateICmpULT (idx, nelem); - fn.do_return (ret); - } - for_check_fn.add_overload (fn); - - // index variabe for for loop - for_index_fn.stash_name ("for_index"); - - fn = create_function (jit_convention::internal, "octave_jit_for_range_idx", - scalar, range, index); - body = fn.new_block (); - builder.SetInsertPoint (body); - { - llvm::Value *idx = fn.argument (1); - llvm::Value *didx = builder.CreateSIToFP (idx, scalar_t); - llvm::Value *rng = fn.argument (0); - llvm::Value *base = builder.CreateExtractValue (rng, 0); - llvm::Value *inc = builder.CreateExtractValue (rng, 2); - - llvm::Value *ret = builder.CreateFMul (didx, inc); - ret = builder.CreateFAdd (base, ret); - fn.do_return (ret); - } - for_index_fn.add_overload (fn); - - // logically true - logically_true_fn.stash_name ("logically_true"); - - jit_function gripe_nantl - = create_function (jit_convention::external, - "octave_jit_gripe_nan_to_logical_conversion", 0); - gripe_nantl.mark_can_error (); - - fn = create_function (jit_convention::internal, - "octave_jit_logically_true_scalar", boolean, scalar); - fn.mark_can_error (); - - body = fn.new_block (); - builder.SetInsertPoint (body); - { - llvm::BasicBlock *error_block = fn.new_block ("error"); - llvm::BasicBlock *normal_block = fn.new_block ("normal"); - - llvm::Value *check = builder.CreateFCmpUNE (fn.argument (0), - fn.argument (0)); - builder.CreateCondBr (check, error_block, normal_block); - - builder.SetInsertPoint (error_block); - gripe_nantl.call (); - builder.CreateBr (normal_block); - builder.SetInsertPoint (normal_block); - - llvm::Value *zero = llvm::ConstantFP::get (scalar_t, 0); - llvm::Value *ret = builder.CreateFCmpONE (fn.argument (0), zero); - fn.do_return (ret); - } - logically_true_fn.add_overload (fn); - - // logically_true boolean - fn = create_identity (boolean); - logically_true_fn.add_overload (fn); - - // make_range - // FIXME: May be benificial to implement all in LLVM - make_range_fn.stash_name ("make_range"); - jit_function compute_nelem - = create_function (jit_convention::external, "octave_jit_compute_nelem", - index, scalar, scalar, scalar); - - fn = create_function (jit_convention::internal, "octave_jit_make_range", - range, scalar, scalar, scalar); - body = fn.new_block (); - builder.SetInsertPoint (body); - { - llvm::Value *base = fn.argument (0); - llvm::Value *limit = fn.argument (1); - llvm::Value *inc = fn.argument (2); - llvm::Value *nelem = compute_nelem.call (base, limit, inc); - - llvm::Value *dzero = llvm::ConstantFP::get (scalar_t, 0); - llvm::Value *izero = llvm::ConstantInt::get (index_t, 0); - llvm::Value *rng = llvm::ConstantStruct::get (range_t, dzero, dzero, dzero, - izero, NULL); - rng = builder.CreateInsertValue (rng, base, 0); - rng = builder.CreateInsertValue (rng, limit, 1); - rng = builder.CreateInsertValue (rng, inc, 2); - rng = builder.CreateInsertValue (rng, nelem, 3); - fn.do_return (rng); - } - make_range_fn.add_overload (fn); - - // paren_subsref - jit_type *jit_int = intN (sizeof (int) * 8); - llvm::Type *int_t = jit_int->to_llvm (); - jit_function ginvalid_index - = create_function (jit_convention::external, "octave_jit_ginvalid_index", - 0); - jit_function gindex_range = create_function (jit_convention::external, - "octave_jit_gindex_range", - 0, jit_int, jit_int, index, - index); - - fn = create_function (jit_convention::internal, "()subsref", scalar, matrix, - scalar); - fn.mark_can_error (); - - body = fn.new_block (); - builder.SetInsertPoint (body); - { - llvm::Value *one = llvm::ConstantInt::get (index_t, 1); - llvm::Value *ione; - if (index_t == int_t) - ione = one; - else - ione = llvm::ConstantInt::get (int_t, 1); - - llvm::Value *undef = llvm::UndefValue::get (scalar_t); - llvm::Value *mat = fn.argument (0); - llvm::Value *idx = fn.argument (1); - - // convert index to scalar to integer, and check index >= 1 - llvm::Value *int_idx = builder.CreateFPToSI (idx, index_t); - llvm::Value *check_idx = builder.CreateSIToFP (int_idx, scalar_t); - llvm::Value *cond0 = builder.CreateFCmpUNE (idx, check_idx); - llvm::Value *cond1 = builder.CreateICmpSLT (int_idx, one); - llvm::Value *cond = builder.CreateOr (cond0, cond1); - - llvm::BasicBlock *done = fn.new_block ("done"); - llvm::BasicBlock *conv_error = fn.new_block ("conv_error", done); - llvm::BasicBlock *normal = fn.new_block ("normal", done); - builder.CreateCondBr (cond, conv_error, normal); - - builder.SetInsertPoint (conv_error); - ginvalid_index.call (); - builder.CreateBr (done); - - builder.SetInsertPoint (normal); - llvm::Value *len = builder.CreateExtractValue (mat, - llvm::ArrayRef<unsigned> (2)); - cond = builder.CreateICmpSGT (int_idx, len); - - - llvm::BasicBlock *bounds_error = fn.new_block ("bounds_error", done); - llvm::BasicBlock *success = fn.new_block ("success", done); - builder.CreateCondBr (cond, bounds_error, success); - - builder.SetInsertPoint (bounds_error); - gindex_range.call (ione, ione, int_idx, len); - builder.CreateBr (done); - - builder.SetInsertPoint (success); - llvm::Value *data = builder.CreateExtractValue (mat, - llvm::ArrayRef<unsigned> (1)); - llvm::Value *gep = builder.CreateInBoundsGEP (data, int_idx); - llvm::Value *ret = builder.CreateLoad (gep); - builder.CreateBr (done); - - builder.SetInsertPoint (done); - - llvm::PHINode *merge = llvm::PHINode::Create (scalar_t, 3); - builder.Insert (merge); - merge->addIncoming (undef, conv_error); - merge->addIncoming (undef, bounds_error); - merge->addIncoming (ret, success); - fn.do_return (merge); - } - paren_subsref_fn.add_overload (fn); - - // paren subsasgn - paren_subsasgn_fn.stash_name ("()subsasgn"); - - jit_function resize_paren_subsasgn - = create_function (jit_convention::external, - "octave_jit_paren_subsasgn_impl", matrix, index, scalar); - fn = create_function (jit_convention::internal, "octave_jit_paren_subsasgn", - matrix, matrix, scalar, scalar); - fn.mark_can_error (); - body = fn.new_block (); - builder.SetInsertPoint (body); - { - llvm::Value *one = llvm::ConstantInt::get (index_t, 1); - - llvm::Value *mat = fn.argument (0); - llvm::Value *idx = fn.argument (1); - llvm::Value *value = fn.argument (2); - - llvm::Value *int_idx = builder.CreateFPToSI (idx, index_t); - llvm::Value *check_idx = builder.CreateSIToFP (int_idx, scalar_t); - llvm::Value *cond0 = builder.CreateFCmpUNE (idx, check_idx); - llvm::Value *cond1 = builder.CreateICmpSLT (int_idx, one); - llvm::Value *cond = builder.CreateOr (cond0, cond1); - - llvm::BasicBlock *done = fn.new_block ("done"); - - llvm::BasicBlock *conv_error = fn.new_block ("conv_error", done); - llvm::BasicBlock *normal = fn.new_block ("normal", done); - builder.CreateCondBr (cond, conv_error, normal); - builder.SetInsertPoint (conv_error); - ginvalid_index.call (); - builder.CreateBr (done); - - builder.SetInsertPoint (normal); - llvm::Value *len = builder.CreateExtractValue (mat, - llvm::ArrayRef<unsigned> (2)); - cond0 = builder.CreateICmpSGT (int_idx, len); - - llvm::Value *rcount = builder.CreateExtractValue (mat, 0); - rcount = builder.CreateLoad (rcount); - cond1 = builder.CreateICmpSGT (rcount, one); - cond = builder.CreateOr (cond0, cond1); - - llvm::BasicBlock *bounds_error = fn.new_block ("bounds_error", done); - llvm::BasicBlock *success = fn.new_block ("success", done); - builder.CreateCondBr (cond, bounds_error, success); - - // resize on out of bounds access - builder.SetInsertPoint (bounds_error); - llvm::Value *resize_result = resize_paren_subsasgn.call (int_idx, value); - builder.CreateBr (done); - - builder.SetInsertPoint (success); - llvm::Value *data = builder.CreateExtractValue (mat, - llvm::ArrayRef<unsigned> (1)); - llvm::Value *gep = builder.CreateInBoundsGEP (data, int_idx); - builder.CreateStore (value, gep); - builder.CreateBr (done); - - builder.SetInsertPoint (done); - - llvm::PHINode *merge = llvm::PHINode::Create (matrix_t, 3); - builder.Insert (merge); - merge->addIncoming (mat, conv_error); - merge->addIncoming (resize_result, bounds_error); - merge->addIncoming (mat, success); - fn.do_return (merge); - } - paren_subsasgn_fn.add_overload (fn); - - fn = create_function (jit_convention::external, - "octave_jit_paren_subsasgn_matrix_range", matrix, - matrix, range, scalar); - fn.mark_can_error (); - paren_subsasgn_fn.add_overload (fn); - - casts[any->type_id ()].stash_name ("(any)"); - casts[scalar->type_id ()].stash_name ("(scalar)"); - casts[complex->type_id ()].stash_name ("(complex)"); - casts[matrix->type_id ()].stash_name ("(matrix)"); - - // cast any <- matrix - fn = create_function (jit_convention::external, "octave_jit_cast_any_matrix", - any, matrix); - casts[any->type_id ()].add_overload (fn); - - // cast matrix <- any - fn = create_function (jit_convention::external, "octave_jit_cast_matrix_any", - matrix, any); - casts[matrix->type_id ()].add_overload (fn); - - // cast any <- scalar - fn = create_function (jit_convention::external, "octave_jit_cast_any_scalar", - any, scalar); - casts[any->type_id ()].add_overload (fn); - - // cast scalar <- any - fn = create_function (jit_convention::external, "octave_jit_cast_scalar_any", - scalar, any); - casts[scalar->type_id ()].add_overload (fn); - - // cast any <- complex - fn = create_function (jit_convention::external, "octave_jit_cast_any_complex", - any, complex); - casts[any->type_id ()].add_overload (fn); - - // cast complex <- any - fn = create_function (jit_convention::external, "octave_jit_cast_complex_any", - complex, any); - casts[complex->type_id ()].add_overload (fn); - - // cast complex <- scalar - fn = create_function (jit_convention::internal, - "octave_jit_cast_complex_scalar", complex, scalar); - body = fn.new_block (); - builder.SetInsertPoint (body); - { - llvm::Value *zero = llvm::ConstantFP::get (scalar_t, 0); - fn.do_return (complex_new (fn.argument (0), zero)); - } - casts[complex->type_id ()].add_overload (fn); - - // cast scalar <- complex - fn = create_function (jit_convention::internal, - "octave_jit_cast_scalar_complex", scalar, complex); - body = fn.new_block (); - builder.SetInsertPoint (body); - fn.do_return (complex_real (fn.argument (0))); - casts[scalar->type_id ()].add_overload (fn); - - // cast any <- any - fn = create_identity (any); - casts[any->type_id ()].add_overload (fn); - - // cast scalar <- scalar - fn = create_identity (scalar); - casts[scalar->type_id ()].add_overload (fn); - - // cast complex <- complex - fn = create_identity (complex); - casts[complex->type_id ()].add_overload (fn); - - // -------------------- builtin functions -------------------- - add_builtin ("#unknown_function"); - unknown_function = builtins["#unknown_function"]; - - add_builtin ("sin"); - register_intrinsic ("sin", llvm::Intrinsic::sin, scalar, scalar); - register_generic ("sin", matrix, matrix); - - add_builtin ("cos"); - register_intrinsic ("cos", llvm::Intrinsic::cos, scalar, scalar); - register_generic ("cos", matrix, matrix); - - add_builtin ("exp"); - register_intrinsic ("exp", llvm::Intrinsic::cos, scalar, scalar); - register_generic ("exp", matrix, matrix); - - casts.resize (next_id + 1); - jit_function any_id = create_identity (any); - jit_function release_any = get_release (any); - std::vector<jit_type *> args; - args.resize (1); - - for (std::map<std::string, jit_type *>::iterator iter = builtins.begin (); - iter != builtins.end (); ++iter) - { - jit_type *btype = iter->second; - args[0] = btype; - - release_fn.add_overload (jit_function (release_any, 0, args)); - casts[any->type_id ()].add_overload (jit_function (any_id, any, args)); - - args[0] = any; - casts[btype->type_id ()].add_overload (jit_function (any_id, btype, - args)); - } -} - -void -jit_typeinfo::add_print (jit_type *ty) -{ - std::stringstream name; - name << "octave_jit_print_" << ty->name (); - jit_function fn = create_function (jit_convention::external, name.str (), 0, - intN (8), ty); - print_fn.add_overload (fn); -} - -// FIXME: cp between add_binary_op, add_binary_icmp, and add_binary_fcmp -void -jit_typeinfo::add_binary_op (jit_type *ty, int op, int llvm_op) -{ - std::stringstream fname; - octave_value::binary_op ov_op = static_cast<octave_value::binary_op>(op); - fname << "octave_jit_" << octave_value::binary_op_as_string (ov_op) - << "_" << ty->name (); - - jit_function fn = create_function (jit_convention::internal, fname.str (), - ty, ty, ty); - llvm::BasicBlock *block = fn.new_block (); - builder.SetInsertPoint (block); - llvm::Instruction::BinaryOps temp - = static_cast<llvm::Instruction::BinaryOps>(llvm_op); - - llvm::Value *ret = builder.CreateBinOp (temp, fn.argument (0), - fn.argument (1)); - fn.do_return (ret); - binary_ops[op].add_overload (fn); -} - -void -jit_typeinfo::add_binary_icmp (jit_type *ty, int op, int llvm_op) -{ - std::stringstream fname; - octave_value::binary_op ov_op = static_cast<octave_value::binary_op>(op); - fname << "octave_jit" << octave_value::binary_op_as_string (ov_op) - << "_" << ty->name (); - - jit_function fn = create_function (jit_convention::internal, fname.str (), - boolean, ty, ty); - llvm::BasicBlock *block = fn.new_block (); - builder.SetInsertPoint (block); - llvm::CmpInst::Predicate temp - = static_cast<llvm::CmpInst::Predicate>(llvm_op); - llvm::Value *ret = builder.CreateICmp (temp, fn.argument (0), - fn.argument (1)); - fn.do_return (ret); - binary_ops[op].add_overload (fn); -} - -void -jit_typeinfo::add_binary_fcmp (jit_type *ty, int op, int llvm_op) -{ - std::stringstream fname; - octave_value::binary_op ov_op = static_cast<octave_value::binary_op>(op); - fname << "octave_jit" << octave_value::binary_op_as_string (ov_op) - << "_" << ty->name (); - - jit_function fn = create_function (jit_convention::internal, fname.str (), - boolean, ty, ty); - llvm::BasicBlock *block = fn.new_block (); - builder.SetInsertPoint (block); - llvm::CmpInst::Predicate temp - = static_cast<llvm::CmpInst::Predicate>(llvm_op); - llvm::Value *ret = builder.CreateFCmp (temp, fn.argument (0), - fn.argument (1)); - fn.do_return (ret); - binary_ops[op].add_overload (fn); -} - -jit_function -jit_typeinfo::create_function (jit_convention::type cc, const llvm::Twine& name, - jit_type *ret, - const std::vector<jit_type *>& args) -{ - jit_function result (module, cc, name, ret, args); - return result; -} - -jit_function -jit_typeinfo::create_identity (jit_type *type) -{ - size_t id = type->type_id (); - if (id >= identities.size ()) - identities.resize (id + 1); - - if (! identities[id].valid ()) - { - jit_function fn = create_function (jit_convention::internal, "id", type, - type); - llvm::BasicBlock *body = fn.new_block (); - builder.SetInsertPoint (body); - fn.do_return (fn.argument (0)); - return identities[id] = fn; - } - - return identities[id]; -} - -llvm::Value * -jit_typeinfo::do_insert_error_check (void) -{ - return builder.CreateLoad (lerror_state); -} - -void -jit_typeinfo::add_builtin (const std::string& name) -{ - jit_type *btype = new_type (name, any, any->to_llvm ()); - builtins[name] = btype; - - octave_builtin *ov_builtin = find_builtin (name); - if (ov_builtin) - ov_builtin->stash_jit (*btype); -} - -void -jit_typeinfo::register_intrinsic (const std::string& name, size_t iid, - jit_type *result, - const std::vector<jit_type *>& args) -{ - jit_type *builtin_type = builtins[name]; - size_t nargs = args.size (); - llvm::SmallVector<llvm::Type *, 5> llvm_args (nargs); - for (size_t i = 0; i < nargs; ++i) - llvm_args[i] = args[i]->to_llvm (); - - llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID> (iid); - llvm::Function *ifun = llvm::Intrinsic::getDeclaration (module, id, - llvm_args); - std::stringstream fn_name; - fn_name << "octave_jit_" << name; - - std::vector<jit_type *> args1 (nargs + 1); - args1[0] = builtin_type; - std::copy (args.begin (), args.end (), args1.begin () + 1); - - // The first argument will be the Octave function, but we already know that - // the function call is the equivalent of the intrinsic, so we ignore it and - // call the intrinsic with the remaining arguments. - jit_function fn = create_function (jit_convention::internal, fn_name.str (), - result, args1); - llvm::BasicBlock *body = fn.new_block (); - builder.SetInsertPoint (body); - - llvm::SmallVector<llvm::Value *, 5> fargs (nargs); - for (size_t i = 0; i < nargs; ++i) - fargs[i] = fn.argument (i + 1); - - llvm::Value *ret = builder.CreateCall (ifun, fargs); - fn.do_return (ret); - paren_subsref_fn.add_overload (fn); -} - -octave_builtin * -jit_typeinfo::find_builtin (const std::string& name) -{ - // FIXME: Finalize what we want to store in octave_builtin, then add functions - // to access these values in octave_value - octave_value ov_builtin = symbol_table::find (name); - return dynamic_cast<octave_builtin *> (ov_builtin.internal_rep ()); -} - -void -jit_typeinfo::register_generic (const std::string&, jit_type *, - const std::vector<jit_type *>&) -{ - // FIXME: Implement -} - -jit_function -jit_typeinfo::mirror_binary (const jit_function& fn) -{ - jit_function ret = create_function (jit_convention::internal, - fn.name () + "_reverse", - fn.result (), fn.argument_type (1), - fn.argument_type (0)); - if (fn.can_error ()) - ret.mark_can_error (); - - llvm::BasicBlock *body = ret.new_block (); - builder.SetInsertPoint (body); - llvm::Value *result = fn.call (ret.argument (1), ret.argument (0)); - if (ret.result ()) - ret.do_return (result); - else - ret.do_return (); - - return ret; -} - -llvm::Value * -jit_typeinfo::pack_complex (llvm::Value *cplx) -{ - llvm::Type *complex_ret = instance->complex_ret; - llvm::Value *real = builder.CreateExtractElement (cplx, builder.getInt32 (0)); - llvm::Value *imag = builder.CreateExtractElement (cplx, builder.getInt32 (1)); - llvm::Value *ret = llvm::UndefValue::get (complex_ret); - ret = builder.CreateInsertValue (ret, real, 0); - return builder.CreateInsertValue (ret, imag, 1); -} - -llvm::Value * -jit_typeinfo::unpack_complex (llvm::Value *result) -{ - llvm::Type *complex_t = get_complex ()->to_llvm (); - llvm::Value *real = builder.CreateExtractValue (result, 0); - llvm::Value *imag = builder.CreateExtractValue (result, 1); - llvm::Value *ret = llvm::UndefValue::get (complex_t); - ret = builder.CreateInsertElement (ret, real, builder.getInt32 (0)); - return builder.CreateInsertElement (ret, imag, builder.getInt32 (1)); -} - -llvm::Value * -jit_typeinfo::complex_real (llvm::Value *cx) -{ - return builder.CreateExtractElement (cx, builder.getInt32 (0)); -} - -llvm::Value * -jit_typeinfo::complex_real (llvm::Value *cx, llvm::Value *real) -{ - return builder.CreateInsertElement (cx, real, builder.getInt32 (0)); -} - -llvm::Value * -jit_typeinfo::complex_imag (llvm::Value *cx) -{ - return builder.CreateExtractElement (cx, builder.getInt32 (1)); -} - -llvm::Value * -jit_typeinfo::complex_imag (llvm::Value *cx, llvm::Value *imag) -{ - return builder.CreateInsertElement (cx, imag, builder.getInt32 (1)); -} - -llvm::Value * -jit_typeinfo::complex_new (llvm::Value *real, llvm::Value *imag) -{ - llvm::Value *ret = llvm::UndefValue::get (complex->to_llvm ()); - ret = complex_real (ret, real); - return complex_imag (ret, imag); -} - -void -jit_typeinfo::create_int (size_t nbits) -{ - std::stringstream tname; - tname << "int" << nbits; - ints[nbits] = new_type (tname.str (), any, llvm::Type::getIntNTy (context, - nbits)); -} - -jit_type * -jit_typeinfo::intN (size_t nbits) const -{ - std::map<size_t, jit_type *>::const_iterator iter = ints.find (nbits); - if (iter != ints.end ()) - return iter->second; - - fail ("No such integer type"); -} - -jit_type * -jit_typeinfo::do_type_of (const octave_value &ov) const -{ - if (ov.is_function ()) - { - // FIXME: This is ugly, we need to finalize how we want to to this, then - // have octave_value fully support the needed functionality - octave_builtin *builtin - = dynamic_cast<octave_builtin *> (ov.internal_rep ()); - return builtin && builtin->to_jit () ? builtin->to_jit () - : unknown_function; - } - - if (ov.is_range ()) - return get_range (); - - if (ov.is_double_type ()) - { - if (ov.is_real_scalar ()) - return get_scalar (); - - if (ov.is_matrix_type ()) - return get_matrix (); - } - - if (ov.is_complex_scalar ()) - return get_complex (); - - return get_any (); -} - -jit_type* -jit_typeinfo::new_type (const std::string& name, jit_type *parent, - llvm::Type *llvm_type) -{ - jit_type *ret = new jit_type (name, parent, llvm_type, next_id++); - id_to_type.push_back (ret); - return ret; -} - -// -------------------- jit_use -------------------- -jit_block * -jit_use::user_parent (void) const -{ - return muser->parent (); -} - -// -------------------- jit_value -------------------- -jit_value::~jit_value (void) -{} - -jit_block * -jit_value::first_use_block (void) -{ - jit_use *use = first_use (); - while (use) - { - if (! isa<jit_error_check> (use->user ())) - return use->user_parent (); - - use = use->next (); - } - - return 0; -} - -void -jit_value::replace_with (jit_value *value) -{ - while (first_use ()) - { - jit_instruction *user = first_use ()->user (); - size_t idx = first_use ()->index (); - user->stash_argument (idx, value); - } -} - -#define JIT_METH(clname) \ - void \ - jit_ ## clname::accept (jit_ir_walker& walker) \ - { \ - walker.visit (*this); \ - } - -JIT_VISIT_IR_NOTEMPLATE -#undef JIT_METH - -std::ostream& -operator<< (std::ostream& os, const jit_value& value) -{ - return value.short_print (os); -} - -std::ostream& -jit_print (std::ostream& os, jit_value *avalue) -{ - if (avalue) - return avalue->print (os); - return os << "NULL"; -} - -// -------------------- jit_instruction -------------------- -void -jit_instruction::remove (void) -{ - if (mparent) - mparent->remove (mlocation); - resize_arguments (0); -} - -llvm::BasicBlock * -jit_instruction::parent_llvm (void) const -{ - return mparent->to_llvm (); -} - -std::ostream& -jit_instruction::short_print (std::ostream& os) const -{ - if (type ()) - jit_print (os, type ()) << ": "; - return os << "#" << mid; -} - -void -jit_instruction::do_construct_ssa (size_t start, size_t end) -{ - for (size_t i = start; i < end; ++i) - { - jit_value *arg = argument (i); - jit_variable *var = dynamic_cast<jit_variable *> (arg); - if (var && var->has_top ()) - stash_argument (i, var->top ()); - } -} - -// -------------------- jit_block -------------------- -void -jit_block::replace_with (jit_value *value) -{ - assert (isa<jit_block> (value)); - jit_block *block = static_cast<jit_block *> (value); - - jit_value::replace_with (block); - - while (ILIST_T::first_use ()) - { - jit_phi_incomming *incomming = ILIST_T::first_use (); - incomming->stash_value (block); - } -} - -void -jit_block::replace_in_phi (jit_block *ablock, jit_block *with) -{ - jit_phi_incomming *node = ILIST_T::first_use (); - while (node) - { - jit_phi_incomming *prev = node; - node = node->next (); - - if (prev->user_parent () == ablock) - prev->stash_value (with); - } -} - -jit_block * -jit_block::maybe_merge () -{ - if (successor_count () == 1 && successor (0) != this - && (successor (0)->use_count () == 1 || instructions.size () == 1)) - { - jit_block *to_merge = successor (0); - merge (*to_merge); - return to_merge; - } - - return 0; -} - -void -jit_block::merge (jit_block& block) -{ - // the merge block will contain a new terminator - jit_terminator *old_term = terminator (); - if (old_term) - old_term->remove (); - - bool was_empty = end () == begin (); - iterator merge_begin = end (); - if (! was_empty) - --merge_begin; - - instructions.splice (end (), block.instructions); - if (was_empty) - merge_begin = begin (); - else - ++merge_begin; - - // now merge_begin points to the start of the new instructions, we must - // update their parent information - for (iterator iter = merge_begin; iter != end (); ++iter) - { - jit_instruction *instr = *iter; - instr->stash_parent (this, iter); - } - - block.replace_with (this); -} - -jit_instruction * -jit_block::prepend (jit_instruction *instr) -{ - instructions.push_front (instr); - instr->stash_parent (this, instructions.begin ()); - return instr; -} - -jit_instruction * -jit_block::prepend_after_phi (jit_instruction *instr) -{ - // FIXME: Make this O(1) - for (iterator iter = begin (); iter != end (); ++iter) - { - jit_instruction *temp = *iter; - if (! isa<jit_phi> (temp)) - { - insert_before (iter, instr); - return instr; - } - } - - return append (instr); -} - -void -jit_block::internal_append (jit_instruction *instr) -{ - instructions.push_back (instr); - instr->stash_parent (this, --instructions.end ()); -} - -jit_instruction * -jit_block::insert_before (iterator loc, jit_instruction *instr) -{ - iterator iloc = instructions.insert (loc, instr); - instr->stash_parent (this, iloc); - return instr; -} - -jit_instruction * -jit_block::insert_after (iterator loc, jit_instruction *instr) -{ - ++loc; - iterator iloc = instructions.insert (loc, instr); - instr->stash_parent (this, iloc); - return instr; -} - -jit_terminator * -jit_block::terminator (void) const -{ - assert (this); - if (instructions.empty ()) - return 0; - - jit_instruction *last = instructions.back (); - return dynamic_cast<jit_terminator *> (last); -} - -bool -jit_block::branch_alive (jit_block *asucc) const -{ - return terminator ()->alive (asucc); -} - -jit_block * -jit_block::successor (size_t i) const -{ - jit_terminator *term = terminator (); - return term->successor (i); -} - -size_t -jit_block::successor_count (void) const -{ - jit_terminator *term = terminator (); - return term ? term->successor_count () : 0; -} - -llvm::BasicBlock * -jit_block::to_llvm (void) const -{ - return llvm::cast<llvm::BasicBlock> (llvm_value); -} - -std::ostream& -jit_block::print_dom (std::ostream& os) const -{ - short_print (os); - os << ":\n"; - os << " mid: " << mid << std::endl; - os << " predecessors: "; - for (jit_use *use = first_use (); use; use = use->next ()) - os << *use->user_parent () << " "; - os << std::endl; - - os << " successors: "; - for (size_t i = 0; i < successor_count (); ++i) - os << *successor (i) << " "; - os << std::endl; - - os << " idom: "; - if (idom) - os << *idom; - else - os << "NULL"; - os << std::endl; - os << " df: "; - for (df_iterator iter = df_begin (); iter != df_end (); ++iter) - os << **iter << " "; - os << std::endl; - - os << " dom_succ: "; - for (size_t i = 0; i < dom_succ.size (); ++i) - os << *dom_succ[i] << " "; - - return os << std::endl; -} - -void -jit_block::compute_df (size_t avisit_count) -{ - if (visited (avisit_count)) - return; - - if (use_count () >= 2) - { - for (jit_use *use = first_use (); use; use = use->next ()) - { - jit_block *runner = use->user_parent (); - while (runner != idom) - { - runner->mdf.insert (this); - runner = runner->idom; - } - } - } - - for (size_t i = 0; i < successor_count (); ++i) - successor (i)->compute_df (avisit_count); -} - -bool -jit_block::update_idom (size_t avisit_count) -{ - if (visited (avisit_count) || ! use_count ()) - return false; - - bool changed = false; - for (jit_use *use = first_use (); use; use = use->next ()) - { - jit_block *pred = use->user_parent (); - changed = pred->update_idom (avisit_count) || changed; - } - - jit_use *use = first_use (); - jit_block *new_idom = use->user_parent (); - use = use->next (); - - for (; use; use = use->next ()) - { - jit_block *pred = use->user_parent (); - jit_block *pidom = pred->idom; - if (pidom) - new_idom = idom_intersect (pidom, new_idom); - } - - if (idom != new_idom) - { - idom = new_idom; - return true; - } - - return changed; -} - -void -jit_block::pop_all (void) -{ - for (iterator iter = begin (); iter != end (); ++iter) - { - jit_instruction *instr = *iter; - instr->pop_variable (); - } -} - -jit_block * -jit_block::maybe_split (jit_convert& convert, jit_block *asuccessor) -{ - if (successor_count () > 1) - { - jit_terminator *term = terminator (); - size_t idx = term->successor_index (asuccessor); - jit_block *split = convert.create<jit_block> ("phi_split", mvisit_count); - - // try to place splits where they make sense - if (id () < asuccessor->id ()) - convert.insert_before (asuccessor, split); - else - convert.insert_after (this, split); - - term->stash_argument (idx, split); - jit_branch *br = split->append (convert.create<jit_branch> (asuccessor)); - replace_in_phi (asuccessor, split); - - if (alive ()) - { - split->mark_alive (); - br->infer (); - } - - return split; - } - - return this; -} - -void -jit_block::create_dom_tree (size_t avisit_count) -{ - if (visited (avisit_count)) - return; - - if (idom != this) - idom->dom_succ.push_back (this); - - for (size_t i = 0; i < successor_count (); ++i) - successor (i)->create_dom_tree (avisit_count); -} - -jit_block * -jit_block::idom_intersect (jit_block *i, jit_block *j) -{ - while (i && j && i != j) - { - while (i && i->id () > j->id ()) - i = i->idom; - - while (i && j && j->id () > i->id ()) - j = j->idom; - } - - return i ? i : j; -} - -// -------------------- jit_phi_incomming -------------------- - -jit_block * -jit_phi_incomming::user_parent (void) const -{ return muser->parent (); } - -// -------------------- jit_phi -------------------- -bool -jit_phi::prune (void) -{ - jit_block *p = parent (); - size_t new_idx = 0; - jit_value *unique = argument (1); - - for (size_t i = 0; i < argument_count (); ++i) - { - jit_block *inc = incomming (i); - if (inc->branch_alive (p)) - { - if (unique != argument (i)) - unique = 0; - - if (new_idx != i) - { - stash_argument (new_idx, argument (i)); - mincomming[new_idx].stash_value (inc); - } - - ++new_idx; - } - } - - if (new_idx != argument_count ()) - { - resize_arguments (new_idx); - mincomming.resize (new_idx); - } - - assert (argument_count () > 0); - if (unique) - { - replace_with (unique); - return true; - } - - return false; -} - -bool -jit_phi::infer (void) -{ - jit_block *p = parent (); - if (! p->alive ()) - return false; - - jit_type *infered = 0; - for (size_t i = 0; i < argument_count (); ++i) - { - jit_block *inc = incomming (i); - if (inc->branch_alive (p)) - infered = jit_typeinfo::join (infered, argument_type (i)); - } - - if (infered != type ()) - { - stash_type (infered); - return true; - } - - return false; -} - -llvm::PHINode * -jit_phi::to_llvm (void) const -{ - return llvm::cast<llvm::PHINode> (jit_value::to_llvm ()); -} - -// -------------------- jit_terminator -------------------- -size_t -jit_terminator::successor_index (const jit_block *asuccessor) const -{ - size_t scount = successor_count (); - for (size_t i = 0; i < scount; ++i) - if (successor (i) == asuccessor) - return i; - - panic_impossible (); -} - -bool -jit_terminator::infer (void) -{ - if (! parent ()->alive ()) - return false; - - bool changed = false; - for (size_t i = 0; i < malive.size (); ++i) - if (! malive[i] && check_alive (i)) - { - changed = true; - malive[i] = true; - successor (i)->mark_alive (); - } - - return changed; -} - -llvm::TerminatorInst * -jit_terminator::to_llvm (void) const -{ - return llvm::cast<llvm::TerminatorInst> (jit_value::to_llvm ()); -} - -// -------------------- jit_call -------------------- -bool -jit_call::infer (void) -{ - // FIXME: explain algorithm - for (size_t i = 0; i < argument_count (); ++i) - { - already_infered[i] = argument_type (i); - if (! already_infered[i]) - return false; - } - - jit_type *infered = moperation.result (already_infered); - if (! infered && use_count ()) - { - std::stringstream ss; - ss << "Missing overload in type inference for "; - print (ss, 0); - fail (ss.str ()); - } - - if (infered != type ()) - { - stash_type (infered); - return true; - } - - return false; -} - // -------------------- jit_convert -------------------- jit_convert::jit_convert (llvm::Module *module, tree &tee) : iterator_count (0), short_count (0), breaking (false) @@ -2461,13 +152,13 @@ void jit_convert::visit_anon_fcn_handle (tree_anon_fcn_handle&) { - fail (); + throw jit_fail_exception (); } void jit_convert::visit_argument_list (tree_argument_list&) { - fail (); + throw jit_fail_exception (); } void @@ -2569,25 +260,25 @@ void jit_convert::visit_global_command (tree_global_command&) { - fail (); + throw jit_fail_exception (); } void jit_convert::visit_persistent_command (tree_persistent_command&) { - fail (); + throw jit_fail_exception (); } void jit_convert::visit_decl_elt (tree_decl_elt&) { - fail (); + throw jit_fail_exception (); } void jit_convert::visit_decl_init_list (tree_decl_init_list&) { - fail (); + throw jit_fail_exception (); } void @@ -2676,37 +367,37 @@ void jit_convert::visit_complex_for_command (tree_complex_for_command&) { - fail (); + throw jit_fail_exception (); } void jit_convert::visit_octave_user_script (octave_user_script&) { - fail (); + throw jit_fail_exception (); } void jit_convert::visit_octave_user_function (octave_user_function&) { - fail (); + throw jit_fail_exception (); } void jit_convert::visit_octave_user_function_header (octave_user_function&) { - fail (); + throw jit_fail_exception (); } void jit_convert::visit_octave_user_function_trailer (octave_user_function&) { - fail (); + throw jit_fail_exception (); } void jit_convert::visit_function_def (tree_function_def&) { - fail (); + throw jit_fail_exception (); } void @@ -2718,7 +409,7 @@ void jit_convert::visit_if_clause (tree_if_clause&) { - fail (); + throw jit_fail_exception (); } void @@ -2821,25 +512,25 @@ void jit_convert::visit_matrix (tree_matrix&) { - fail (); + throw jit_fail_exception (); } void jit_convert::visit_cell (tree_cell&) { - fail (); + throw jit_fail_exception (); } void jit_convert::visit_multi_assignment (tree_multi_assignment&) { - fail (); + throw jit_fail_exception (); } void jit_convert::visit_no_op_command (tree_no_op_command&) { - fail (); + throw jit_fail_exception (); } void @@ -2862,50 +553,50 @@ result = create<jit_const_complex> (cv); } else - fail ("Unknown constant"); + throw jit_fail_exception ("Unknown constant"); } void jit_convert::visit_fcn_handle (tree_fcn_handle&) { - fail (); + throw jit_fail_exception (); } void jit_convert::visit_parameter_list (tree_parameter_list&) { - fail (); + throw jit_fail_exception (); } void jit_convert::visit_postfix_expression (tree_postfix_expression&) { - fail (); + throw jit_fail_exception (); } void jit_convert::visit_prefix_expression (tree_prefix_expression&) { - fail (); + throw jit_fail_exception (); } void jit_convert::visit_return_command (tree_return_command&) { - fail (); + throw jit_fail_exception (); } void jit_convert::visit_return_list (tree_return_list&) { - fail (); + throw jit_fail_exception (); } void jit_convert::visit_simple_assignment (tree_simple_assignment& tsa) { if (tsa.op_type () != octave_value::op_asn_eq) - fail ("Unsupported assign"); + throw jit_fail_exception ("Unsupported assign"); // resolve rhs tree_expression *rhs = tsa.right_hand_side (); @@ -2970,31 +661,31 @@ void jit_convert::visit_switch_case (tree_switch_case&) { - fail (); + throw jit_fail_exception (); } void jit_convert::visit_switch_case_list (tree_switch_case_list&) { - fail (); + throw jit_fail_exception (); } void jit_convert::visit_switch_command (tree_switch_command&) { - fail (); + throw jit_fail_exception (); } void jit_convert::visit_try_catch_command (tree_try_catch_command&) { - fail (); + throw jit_fail_exception (); } void jit_convert::visit_unwind_protect_command (tree_unwind_protect_command&) { - fail (); + throw jit_fail_exception (); } void @@ -3042,7 +733,7 @@ void jit_convert::visit_do_until_command (tree_do_until_command&) { - fail (); + throw jit_fail_exception (); } void @@ -3089,18 +780,18 @@ { std::string type = exp.type_tags (); if (! (type.size () == 1 && type[0] == '(')) - fail ("Unsupported index operation"); + throw jit_fail_exception ("Unsupported index operation"); std::list<tree_argument_list *> args = exp.arg_lists (); if (args.size () != 1) - fail ("Bad number of arguments in tree_index_expression"); + throw jit_fail_exception ("Bad number of arguments in tree_index_expression"); tree_argument_list *arg_list = args.front (); if (! arg_list) - fail ("null argument list"); + throw jit_fail_exception ("null argument list"); if (arg_list->size () != 1) - fail ("Bad number of arguments in arg_list"); + throw jit_fail_exception ("Bad number of arguments in arg_list"); tree_expression *tree_object = exp.expression (); jit_value *object = visit (tree_object); @@ -3114,7 +805,7 @@ jit_convert::do_assign (tree_expression *exp, jit_value *rhs, bool artificial) { if (! exp) - fail ("NULL lhs in assign"); + throw jit_fail_exception ("NULL lhs in assign"); if (isa<tree_identifier> (exp)) return do_assign (exp->name (), rhs, exp->print_result (), artificial); @@ -3134,7 +825,7 @@ return rhs; } else - fail ("Unsupported assignment"); + throw jit_fail_exception ("Unsupported assignment"); } jit_value * @@ -3679,7 +1370,7 @@ for (size_t i = 0; i < args.size (); ++i) args[i] = call.argument (i); - llvm::Value *ret = ol.call (args); + llvm::Value *ret = ol.call (builder, args); call.stash_llvm (ret); } @@ -3691,14 +1382,14 @@ arg = builder.CreateLoad (arg); const jit_function& ol = extract.overload (); - extract.stash_llvm (ol.call (arg)); + extract.stash_llvm (ol.call (builder, arg)); } void jit_convert::convert_llvm::visit (jit_store_argument& store) { const jit_function& ol = store.overload (); - llvm::Value *arg_value = ol.call (store.result ()); + llvm::Value *arg_value = ol.call (builder, store.result ()); llvm::Value *arg = arguments[store.name ()]; store.stash_llvm (builder.CreateStore (arg_value, arg)); } @@ -3717,13 +1408,13 @@ void jit_convert::convert_llvm::visit (jit_variable&) { - fail ("ERROR: SSA construction should remove all variables"); + throw jit_fail_exception ("ERROR: SSA construction should remove all variables"); } void jit_convert::convert_llvm::visit (jit_error_check& check) { - llvm::Value *cond = jit_typeinfo::insert_error_check (); + llvm::Value *cond = jit_typeinfo::insert_error_check (builder); llvm::Value *br = builder.CreateCondBr (cond, check.successor_llvm (0), check.successor_llvm (1)); check.stash_llvm (br); @@ -3742,14 +1433,14 @@ { const jit_function& ol = jit_typeinfo::get_grab (new_value->type ()); if (ol.valid ()) - assign.stash_llvm (ol.call (new_value)); + assign.stash_llvm (ol.call (builder, new_value)); } jit_value *overwrite = assign.overwrite (); if (isa<jit_assign_base> (overwrite)) { const jit_function& ol = jit_typeinfo::get_release (overwrite->type ()); - ol.call (overwrite); + ol.call (builder, overwrite); } }
--- a/src/pt-jit.h +++ b/src/pt-jit.h @@ -25,15 +25,8 @@ #ifdef HAVE_LLVM -#include <list> -#include <map> -#include <set> -#include <stdexcept> -#include <vector> -#include <stack> +#include "jit-ir.h" -#include "Array.h" -#include "Range.h" #include "pt-walk.h" // -------------------- Current status -------------------- @@ -61,1957 +54,6 @@ // 3. ... // --------------------------------------------------------- - -// we don't want to include llvm headers here, as they require -// __STDC_LIMIT_MACROS and __STDC_CONSTANT_MACROS be defined in the entire -// compilation unit -namespace llvm -{ - class Value; - class Module; - class FunctionPassManager; - class PassManager; - class ExecutionEngine; - class Function; - class BasicBlock; - class LLVMContext; - class Type; - class StructType; - class Twine; - class GlobalVariable; - class TerminatorInst; - class PHINode; -} - -// llvm doesn't provide this, and it's really useful for debugging -std::ostream& operator<< (std::ostream& os, const llvm::Value& v); - -class octave_base_value; -class octave_builtin; -class octave_value; -class tree; -class tree_expression; - -template <typename HOLDER_T, typename SUB_T> -class jit_internal_node; - -// jit_internal_list and jit_internal_node implement generic embedded doubly -// linked lists. List items extend from jit_internal_list, and can be placed -// in nodes of type jit_internal_node. We use CRTP twice. -template <typename LIST_T, typename NODE_T> -class -jit_internal_list -{ - friend class jit_internal_node<LIST_T, NODE_T>; -public: - jit_internal_list (void) : use_head (0), use_tail (0), muse_count (0) {} - - virtual ~jit_internal_list (void) - { - while (use_head) - use_head->stash_value (0); - } - - NODE_T *first_use (void) const { return use_head; } - - size_t use_count (void) const { return muse_count; } -private: - NODE_T *use_head; - NODE_T *use_tail; - size_t muse_count; -}; - -// a node for internal linked lists -template <typename LIST_T, typename NODE_T> -class -jit_internal_node -{ -public: - typedef jit_internal_list<LIST_T, NODE_T> jit_ilist; - - jit_internal_node (void) : mvalue (0), mnext (0), mprev (0) {} - - ~jit_internal_node (void) { remove (); } - - LIST_T *value (void) const { return mvalue; } - - void stash_value (LIST_T *avalue) - { - remove (); - - mvalue = avalue; - - if (mvalue) - { - jit_ilist *ilist = mvalue; - NODE_T *sthis = static_cast<NODE_T *> (this); - if (ilist->use_head) - { - ilist->use_tail->mnext = sthis; - mprev = ilist->use_tail; - } - else - ilist->use_head = sthis; - - ilist->use_tail = sthis; - ++ilist->muse_count; - } - } - - NODE_T *next (void) const { return mnext; } - - NODE_T *prev (void) const { return mprev; } -private: - void remove () - { - if (mvalue) - { - jit_ilist *ilist = mvalue; - if (mprev) - mprev->mnext = mnext; - else - // we are the use_head - ilist->use_head = mnext; - - if (mnext) - mnext->mprev = mprev; - else - // we are the use tail - ilist->use_tail = mprev; - - mnext = mprev = 0; - --ilist->muse_count; - mvalue = 0; - } - } - - LIST_T *mvalue; - NODE_T *mnext; - NODE_T *mprev; -}; - -// Use like: isa<jit_phi> (value) -// basically just a short cut type typing dyanmic_cast. -template <typename T, typename U> -bool isa (U *value) -{ - return dynamic_cast<T *> (value); -} - -// jit_range is compatable with the llvm range structure -struct -jit_range -{ - jit_range (const Range& from) : base (from.base ()), limit (from.limit ()), - inc (from.inc ()), nelem (from.nelem ()) - {} - - operator Range () const - { - return Range (base, limit, inc); - } - - bool all_elements_are_ints () const; - - double base; - double limit; - double inc; - octave_idx_type nelem; -}; - -std::ostream& operator<< (std::ostream& os, const jit_range& rng); - -// jit_array is compatable with the llvm array/matrix structures -template <typename T, typename U> -struct -jit_array -{ - jit_array (T& from) : array (new T (from)) - { - update (); - } - - void update (void) - { - ref_count = array->jit_ref_count (); - slice_data = array->jit_slice_data () - 1; - slice_len = array->capacity (); - dimensions = array->jit_dimensions (); - } - - void update (T *aarray) - { - array = aarray; - update (); - } - - operator T () const - { - return *array; - } - - int *ref_count; - - U *slice_data; - octave_idx_type slice_len; - octave_idx_type *dimensions; - - T *array; -}; - -typedef jit_array<NDArray, double> jit_matrix; - -std::ostream& operator<< (std::ostream& os, const jit_matrix& mat); - -class jit_type; -class jit_value; - -// calling convention -namespace -jit_convention -{ - enum - type - { - // internal to jit - internal, - - // an external C call - external, - - length - }; -} - -// Used to keep track of estimated (infered) types during JIT. This is a -// hierarchical type system which includes both concrete and abstract types. -// -// Current, we only support any and scalar types. If we can't figure out what -// type a variable is, we assign it the any type. This allows us to generate -// code even for the case of poor type inference. -class -jit_type -{ -public: - typedef llvm::Value *(*convert_fn) (llvm::Value *); - - jit_type (const std::string& aname, jit_type *aparent, llvm::Type *allvm_type, - int aid); - - // a user readable type name - const std::string& name (void) const { return mname; } - - // a unique id for the type - int type_id (void) const { return mid; } - - // An abstract base type, may be null - jit_type *parent (void) const { return mparent; } - - // convert to an llvm type - llvm::Type *to_llvm (void) const { return llvm_type; } - - // how this type gets passed as a function argument - llvm::Type *to_llvm_arg (void) const; - - size_t depth (void) const { return mdepth; } - - bool sret (jit_convention::type cc) const { return msret[cc]; } - - void mark_sret (jit_convention::type cc = jit_convention::external) - { msret[cc] = true; } - - bool pointer_arg (jit_convention::type cc) const { return mpointer_arg[cc]; } - - void mark_pointer_arg (jit_convention::type cc = jit_convention::external) - { mpointer_arg[cc] = true; } - - convert_fn pack (jit_convention::type cc) { return mpack[cc]; } - - void set_pack (jit_convention::type cc, convert_fn fn) { mpack[cc] = fn; } - - convert_fn unpack (jit_convention::type cc) { return munpack[cc]; } - - void set_unpack (jit_convention::type cc, convert_fn fn) - { munpack[cc] = fn; } - - llvm::Type *packed_type (jit_convention::type cc) - { return mpacked_type[cc]; } - - void set_packed_type (jit_convention::type cc, llvm::Type *ty) - { mpacked_type[cc] = ty; } -private: - std::string mname; - jit_type *mparent; - llvm::Type *llvm_type; - int mid; - size_t mdepth; - - bool msret[jit_convention::length]; - bool mpointer_arg[jit_convention::length]; - - convert_fn mpack[jit_convention::length]; - convert_fn munpack[jit_convention::length]; - - llvm::Type *mpacked_type[jit_convention::length]; -}; - -// seperate print function to allow easy printing if type is null -std::ostream& jit_print (std::ostream& os, jit_type *atype); - -#define ASSIGN_ARG(i) the_args[i] = arg ## i; -#define JIT_EXPAND(ret, fname, type, isconst, N) \ - ret fname (JIT_PARAM_ARGS OCT_MAKE_DECL_LIST (type, arg, N)) isconst \ - { \ - std::vector<type> the_args (N); \ - OCT_ITERATE_MACRO (ASSIGN_ARG, N); \ - return fname (JIT_PARAMS the_args); \ - } - -// provides a mechanism for calling -class -jit_function -{ - friend std::ostream& operator<< (std::ostream& os, const jit_function& fn); -public: - jit_function (); - - jit_function (llvm::Module *amodule, jit_convention::type acall_conv, - const llvm::Twine& aname, jit_type *aresult, - const std::vector<jit_type *>& aargs); - - jit_function (const jit_function& fn, jit_type *aresult, - const std::vector<jit_type *>& aargs); - - jit_function (const jit_function& fn); - - bool valid (void) const { return llvm_function; } - - std::string name (void) const; - - llvm::BasicBlock *new_block (const std::string& aname = "body", - llvm::BasicBlock *insert_before = 0); - - llvm::Value *call (const std::vector<jit_value *>& in_args) const; - - llvm::Value *call (const std::vector<llvm::Value *>& in_args) const; - -#define JIT_PARAM_ARGS -#define JIT_PARAMS -#define JIT_CALL(N) JIT_EXPAND (llvm::Value *, call, llvm::Value *, const, N) - - JIT_CALL (0); - JIT_CALL (1); - JIT_CALL (2); - JIT_CALL (3); - JIT_CALL (4); - JIT_CALL (5); - -#undef JIT_CALL - -#define JIT_CALL(N) JIT_EXPAND (llvm::Value *, call, jit_value *, const, N) - - JIT_CALL (1); - JIT_CALL (2); - -#undef JIT_CALL -#undef JIT_PARAMS -#undef JIT_PARAM_ARGS - - llvm::Value *argument (size_t idx) const; - - void do_return (llvm::Value *rval = 0); - - llvm::Function *to_llvm (void) const { return llvm_function; } - - // If true, then the return value is passed as a pointer in the first argument - bool sret (void) const { return mresult && mresult->sret (call_conv); } - - bool can_error (void) const { return mcan_error; } - - void mark_can_error (void) { mcan_error = true; } - - jit_type *result (void) const { return mresult; } - - jit_type *argument_type (size_t idx) const - { - assert (idx < args.size ()); - return args[idx]; - } - - const std::vector<jit_type *>& arguments (void) const { return args; } -private: - llvm::Module *module; - llvm::Function *llvm_function; - jit_type *mresult; - std::vector<jit_type *> args; - jit_convention::type call_conv; - bool mcan_error; -}; - -std::ostream& operator<< (std::ostream& os, const jit_function& fn); - - -// Keeps track of overloads for a builtin function. Used for both type inference -// and code generation. -class -jit_operation -{ -public: - void add_overload (const jit_function& func) - { - add_overload (func, func.arguments ()); - } - - void add_overload (const jit_function& func, - const std::vector<jit_type*>& args); - - const jit_function& overload (const std::vector<jit_type *>& types) const; - - jit_type *result (const std::vector<jit_type *>& types) const - { - const jit_function& temp = overload (types); - return temp.result (); - } - -#define JIT_PARAMS -#define JIT_PARAM_ARGS -#define JIT_OVERLOAD(N) \ - JIT_EXPAND (const jit_function&, overload, jit_type *, const, N) \ - JIT_EXPAND (jit_type *, result, jit_type *, const, N) - - JIT_OVERLOAD (1); - JIT_OVERLOAD (2); - JIT_OVERLOAD (3); - -#undef JIT_PARAMS -#undef JIT_PARAM_ARGS - - const std::string& name (void) const { return mname; } - - void stash_name (const std::string& aname) { mname = aname; } -private: - Array<octave_idx_type> to_idx (const std::vector<jit_type*>& types) const; - - std::vector<Array<jit_function> > overloads; - - std::string mname; -}; - -// Get information and manipulate jit types. -class -jit_typeinfo -{ -public: - static void initialize (llvm::Module *m, llvm::ExecutionEngine *e); - - static jit_type *join (jit_type *lhs, jit_type *rhs) - { - return instance->do_join (lhs, rhs); - } - - static jit_type *get_any (void) { return instance->any; } - - static jit_type *get_matrix (void) { return instance->matrix; } - - static jit_type *get_scalar (void) { return instance->scalar; } - - static llvm::Type *get_scalar_llvm (void) - { return instance->scalar->to_llvm (); } - - static jit_type *get_range (void) { return instance->range; } - - static jit_type *get_string (void) { return instance->string; } - - static jit_type *get_bool (void) { return instance->boolean; } - - static jit_type *get_index (void) { return instance->index; } - - static llvm::Type *get_index_llvm (void) - { return instance->index->to_llvm (); } - - static jit_type *get_complex (void) { return instance->complex; } - - static jit_type *type_of (const octave_value& ov) - { - return instance->do_type_of (ov); - } - - static const jit_operation& binary_op (int op) - { - return instance->do_binary_op (op); - } - - static const jit_operation& grab (void) { return instance->grab_fn; } - - static const jit_function& get_grab (jit_type *type) - { - return instance->grab_fn.overload (type); - } - - static const jit_operation& release (void) - { - return instance->release_fn; - } - - static const jit_function& get_release (jit_type *type) - { - return instance->release_fn.overload (type); - } - - static const jit_operation& print_value (void) - { - return instance->print_fn; - } - - static const jit_operation& for_init (void) - { - return instance->for_init_fn; - } - - static const jit_operation& for_check (void) - { - return instance->for_check_fn; - } - - static const jit_operation& for_index (void) - { - return instance->for_index_fn; - } - - static const jit_operation& make_range (void) - { - return instance->make_range_fn; - } - - static const jit_operation& paren_subsref (void) - { - return instance->paren_subsref_fn; - } - - static const jit_operation& paren_subsasgn (void) - { - return instance->paren_subsasgn_fn; - } - - static const jit_operation& logically_true (void) - { - return instance->logically_true_fn; - } - - static const jit_operation& cast (jit_type *result) - { - return instance->do_cast (result); - } - - static const jit_function& cast (jit_type *to, jit_type *from) - { - return instance->do_cast (to, from); - } - - static llvm::Value *insert_error_check (void) - { - return instance->do_insert_error_check (); - } -private: - jit_typeinfo (llvm::Module *m, llvm::ExecutionEngine *e); - - // FIXME: Do these methods really need to be in jit_typeinfo? - jit_type *do_join (jit_type *lhs, jit_type *rhs) - { - // empty case - if (! lhs) - return rhs; - - if (! rhs) - return lhs; - - // check for a shared parent - while (lhs != rhs) - { - if (lhs->depth () > rhs->depth ()) - lhs = lhs->parent (); - else if (lhs->depth () < rhs->depth ()) - rhs = rhs->parent (); - else - { - // we MUST have depth > 0 as any is the base type of everything - do - { - lhs = lhs->parent (); - rhs = rhs->parent (); - } - while (lhs != rhs); - } - } - - return lhs; - } - - jit_type *do_difference (jit_type *lhs, jit_type *) - { - // FIXME: Maybe we can do something smarter? - return lhs; - } - - jit_type *do_type_of (const octave_value &ov) const; - - const jit_operation& do_binary_op (int op) const - { - assert (static_cast<size_t>(op) < binary_ops.size ()); - return binary_ops[op]; - } - - const jit_operation& do_cast (jit_type *to) - { - static jit_operation null_function; - if (! to) - return null_function; - - size_t id = to->type_id (); - if (id >= casts.size ()) - return null_function; - return casts[id]; - } - - const jit_function& do_cast (jit_type *to, jit_type *from) - { - return do_cast (to).overload (from); - } - - jit_type *new_type (const std::string& name, jit_type *parent, - llvm::Type *llvm_type); - - - void add_print (jit_type *ty); - - void add_binary_op (jit_type *ty, int op, int llvm_op); - - void add_binary_icmp (jit_type *ty, int op, int llvm_op); - - void add_binary_fcmp (jit_type *ty, int op, int llvm_op); - - jit_function create_function (jit_convention::type cc, - const llvm::Twine& name, jit_type *ret, - const std::vector<jit_type *>& args - = std::vector<jit_type *> ()); - -#define JIT_PARAM_ARGS jit_convention::type cc, const llvm::Twine& name, \ - jit_type *ret, -#define JIT_PARAMS cc, name, ret, -#define CREATE_FUNCTION(N) JIT_EXPAND(jit_function, create_function, \ - jit_type *, /* empty */, N) - - CREATE_FUNCTION(1); - CREATE_FUNCTION(2); - CREATE_FUNCTION(3); - CREATE_FUNCTION(4); - -#undef JIT_PARAM_ARGS -#undef JIT_PARAMS -#undef CREATE_FUNCTION - - jit_function create_identity (jit_type *type); - - llvm::Value *do_insert_error_check (void); - - void add_builtin (const std::string& name); - - void register_intrinsic (const std::string& name, size_t id, - jit_type *result, jit_type *arg0) - { - std::vector<jit_type *> args (1, arg0); - register_intrinsic (name, id, result, args); - } - - void register_intrinsic (const std::string& name, size_t id, jit_type *result, - const std::vector<jit_type *>& args); - - void register_generic (const std::string& name, jit_type *result, - jit_type *arg0) - { - std::vector<jit_type *> args (1, arg0); - register_generic (name, result, args); - } - - void register_generic (const std::string& name, jit_type *result, - const std::vector<jit_type *>& args); - - octave_builtin *find_builtin (const std::string& name); - - jit_function mirror_binary (const jit_function& fn); - - llvm::Function *wrap_complex (llvm::Function *wrap); - - static llvm::Value *pack_complex (llvm::Value *cplx); - - static llvm::Value *unpack_complex (llvm::Value *result); - - llvm::Value *complex_real (llvm::Value *cx); - - llvm::Value *complex_real (llvm::Value *cx, llvm::Value *real); - - llvm::Value *complex_imag (llvm::Value *cx); - - llvm::Value *complex_imag (llvm::Value *cx, llvm::Value *imag); - - llvm::Value *complex_new (llvm::Value *real, llvm::Value *imag); - - void create_int (size_t nbits); - - jit_type *intN (size_t nbits) const; - - static jit_typeinfo *instance; - - llvm::Module *module; - llvm::ExecutionEngine *engine; - int next_id; - - llvm::GlobalVariable *lerror_state; - - std::vector<jit_type*> id_to_type; - jit_type *any; - jit_type *matrix; - jit_type *scalar; - jit_type *range; - jit_type *string; - jit_type *boolean; - jit_type *index; - jit_type *complex; - jit_type *unknown_function; - std::map<size_t, jit_type *> ints; - std::map<std::string, jit_type *> builtins; - - llvm::StructType *complex_ret; - - std::vector<jit_operation> binary_ops; - jit_operation grab_fn; - jit_operation release_fn; - jit_operation print_fn; - jit_operation for_init_fn; - jit_operation for_check_fn; - jit_operation for_index_fn; - jit_operation logically_true_fn; - jit_operation make_range_fn; - jit_operation paren_subsref_fn; - jit_operation paren_subsasgn_fn; - - // type id -> cast function TO that type - std::vector<jit_operation> casts; - - // type id -> identity function - std::vector<jit_function> identities; -}; - -// The low level octave jit ir -// this ir is close to llvm, but contains information for doing type inference. -// We convert the octave parse tree to this IR directly. - -#define JIT_VISIT_IR_NOTEMPLATE \ - JIT_METH(block); \ - JIT_METH(branch); \ - JIT_METH(cond_branch); \ - JIT_METH(call); \ - JIT_METH(extract_argument); \ - JIT_METH(store_argument); \ - JIT_METH(phi); \ - JIT_METH(variable); \ - JIT_METH(error_check); \ - JIT_METH(assign) \ - JIT_METH(argument) - -#define JIT_VISIT_IR_CONST \ - JIT_METH(const_bool); \ - JIT_METH(const_scalar); \ - JIT_METH(const_complex); \ - JIT_METH(const_index); \ - JIT_METH(const_string); \ - JIT_METH(const_range) - -#define JIT_VISIT_IR_CLASSES \ - JIT_VISIT_IR_NOTEMPLATE \ - JIT_VISIT_IR_CONST - -// forward declare all ir classes -#define JIT_METH(cname) \ - class jit_ ## cname; - -JIT_VISIT_IR_NOTEMPLATE - -#undef JIT_METH - -class jit_convert; - -// ABCs which aren't included in JIT_VISIT_IR_ALL -class jit_instruction; -class jit_terminator; - -template <typename T, jit_type *(*EXTRACT_T)(void), typename PASS_T = T, - bool QUOTE=false> -class jit_const; - -typedef jit_const<bool, jit_typeinfo::get_bool> jit_const_bool; -typedef jit_const<double, jit_typeinfo::get_scalar> jit_const_scalar; -typedef jit_const<Complex, jit_typeinfo::get_complex> jit_const_complex; -typedef jit_const<octave_idx_type, jit_typeinfo::get_index> jit_const_index; - -typedef jit_const<std::string, jit_typeinfo::get_string, const std::string&, - true> jit_const_string; -typedef jit_const<jit_range, jit_typeinfo::get_range, const jit_range&> -jit_const_range; - -class jit_ir_walker; -class jit_use; - -class -jit_value : public jit_internal_list<jit_value, jit_use> -{ -public: - jit_value (void) : llvm_value (0), ty (0), mlast_use (0), - min_worklist (false) {} - - virtual ~jit_value (void); - - bool in_worklist (void) const - { - return min_worklist; - } - - void stash_in_worklist (bool ain_worklist) - { - min_worklist = ain_worklist; - } - - // The block of the first use which is not a jit_error_check - // So this is not necessarily first_use ()->parent (). - jit_block *first_use_block (void); - - // replace all uses with - virtual void replace_with (jit_value *value); - - jit_type *type (void) const { return ty; } - - llvm::Type *type_llvm (void) const - { - return ty ? ty->to_llvm () : 0; - } - - const std::string& type_name (void) const - { - return ty->name (); - } - - void stash_type (jit_type *new_ty) { ty = new_ty; } - - std::string print_string (void) - { - std::stringstream ss; - print (ss); - return ss.str (); - } - - jit_instruction *last_use (void) const { return mlast_use; } - - void stash_last_use (jit_instruction *alast_use) - { - mlast_use = alast_use; - } - - virtual bool needs_release (void) const { return false; } - - virtual std::ostream& print (std::ostream& os, size_t indent = 0) const = 0; - - virtual std::ostream& short_print (std::ostream& os) const - { return print (os); } - - virtual void accept (jit_ir_walker& walker) = 0; - - bool has_llvm (void) const - { - return llvm_value; - } - - llvm::Value *to_llvm (void) const - { - assert (llvm_value); - return llvm_value; - } - - void stash_llvm (llvm::Value *compiled) - { - llvm_value = compiled; - } - -protected: - std::ostream& print_indent (std::ostream& os, size_t indent = 0) const - { - for (size_t i = 0; i < indent * 8; ++i) - os << " "; - return os; - } - - llvm::Value *llvm_value; -private: - jit_type *ty; - jit_instruction *mlast_use; - bool min_worklist; -}; - -std::ostream& operator<< (std::ostream& os, const jit_value& value); -std::ostream& jit_print (std::ostream& os, jit_value *avalue); - -class -jit_use : public jit_internal_node<jit_value, jit_use> -{ -public: - jit_use (void) : muser (0), mindex (0) {} - - // we should really have a move operator, but not until c++11 :( - jit_use (const jit_use& use) : muser (0), mindex (0) - { - *this = use; - } - - jit_use& operator= (const jit_use& use) - { - stash_value (use.value (), use.user (), use.index ()); - return *this; - } - - size_t index (void) const { return mindex; } - - jit_instruction *user (void) const { return muser; } - - jit_block *user_parent (void) const; - - std::list<jit_block *> user_parent_location (void) const; - - void stash_value (jit_value *avalue, jit_instruction *auser = 0, - size_t aindex = -1) - { - jit_internal_node::stash_value (avalue); - mindex = aindex; - muser = auser; - } -private: - jit_instruction *muser; - size_t mindex; -}; - -class -jit_instruction : public jit_value -{ -public: - // FIXME: this code could be so much pretier with varadic templates... - jit_instruction (void) : mid (next_id ()), mparent (0) - {} - - jit_instruction (size_t nargs) : mid (next_id ()), mparent (0) - { - already_infered.reserve (nargs); - marguments.reserve (nargs); - } - -#define STASH_ARG(i) stash_argument (i, arg ## i); -#define JIT_INSTRUCTION_CTOR(N) \ - jit_instruction (OCT_MAKE_DECL_LIST (jit_value *, arg, N)) \ - : already_infered (N), marguments (N), mid (next_id ()), mparent (0) \ - { \ - OCT_ITERATE_MACRO (STASH_ARG, N); \ - } - - JIT_INSTRUCTION_CTOR(1) - JIT_INSTRUCTION_CTOR(2) - JIT_INSTRUCTION_CTOR(3) - JIT_INSTRUCTION_CTOR(4) - -#undef STASH_ARG -#undef JIT_INSTRUCTION_CTOR - - static void reset_ids (void) - { - next_id (true); - } - - jit_value *argument (size_t i) const - { - return marguments[i].value (); - } - - llvm::Value *argument_llvm (size_t i) const - { - assert (argument (i)); - return argument (i)->to_llvm (); - } - - jit_type *argument_type (size_t i) const - { - return argument (i)->type (); - } - - llvm::Type *argument_type_llvm (size_t i) const - { - assert (argument (i)); - return argument_type (i)->to_llvm (); - } - - std::ostream& print_argument (std::ostream& os, size_t i) const - { - if (argument (i)) - return argument (i)->short_print (os); - else - return os << "NULL"; - } - - void stash_argument (size_t i, jit_value *arg) - { - marguments[i].stash_value (arg, this, i); - } - - void push_argument (jit_value *arg) - { - marguments.push_back (jit_use ()); - stash_argument (marguments.size () - 1, arg); - already_infered.push_back (0); - } - - size_t argument_count (void) const - { - return marguments.size (); - } - - void resize_arguments (size_t acount, jit_value *adefault = 0) - { - size_t old = marguments.size (); - marguments.resize (acount); - already_infered.resize (acount); - - if (adefault) - for (size_t i = old; i < acount; ++i) - stash_argument (i, adefault); - } - - const std::vector<jit_use>& arguments (void) const { return marguments; } - - // argument types which have been infered already - const std::vector<jit_type *>& argument_types (void) const - { return already_infered; } - - virtual void push_variable (void) {} - - virtual void pop_variable (void) {} - - virtual void construct_ssa (void) - { - do_construct_ssa (0, argument_count ()); - } - - virtual bool infer (void) { return false; } - - void remove (void); - - virtual std::ostream& short_print (std::ostream& os) const; - - jit_block *parent (void) const { return mparent; } - - std::list<jit_instruction *>::iterator location (void) const - { - return mlocation; - } - - llvm::BasicBlock *parent_llvm (void) const; - - void stash_parent (jit_block *aparent, - std::list<jit_instruction *>::iterator alocation) - { - mparent = aparent; - mlocation = alocation; - } - - size_t id (void) const { return mid; } -protected: - - // Do SSA replacement on arguments in [start, end) - void do_construct_ssa (size_t start, size_t end); - - std::vector<jit_type *> already_infered; -private: - static size_t next_id (bool reset = false) - { - static size_t ret = 0; - if (reset) - return ret = 0; - - return ret++; - } - - std::vector<jit_use> marguments; - - size_t mid; - jit_block *mparent; - std::list<jit_instruction *>::iterator mlocation; -}; - -// defnie accept methods for subclasses -#define JIT_VALUE_ACCEPT \ - virtual void accept (jit_ir_walker& walker); - -// for use as a dummy argument during conversion to LLVM -class -jit_argument : public jit_value -{ -public: - jit_argument (jit_type *atype, llvm::Value *avalue) - { - stash_type (atype); - stash_llvm (avalue); - } - - virtual std::ostream& print (std::ostream& os, size_t indent = 0) const - { - print_indent (os, indent); - return jit_print (os, type ()) << ": DUMMY"; - } - - JIT_VALUE_ACCEPT; -}; - -template <typename T, jit_type *(*EXTRACT_T)(void), typename PASS_T, - bool QUOTE> -class -jit_const : public jit_value -{ -public: - typedef PASS_T pass_t; - - jit_const (PASS_T avalue) : mvalue (avalue) - { - stash_type (EXTRACT_T ()); - } - - PASS_T value (void) const { return mvalue; } - - virtual std::ostream& print (std::ostream& os, size_t indent = 0) const - { - print_indent (os, indent); - jit_print (os, type ()) << ": "; - if (QUOTE) - os << "\""; - os << mvalue; - if (QUOTE) - os << "\""; - return os; - } - - JIT_VALUE_ACCEPT; -private: - T mvalue; -}; - -class jit_phi_incomming; - -class -jit_block : public jit_value, public jit_internal_list<jit_block, - jit_phi_incomming> -{ - typedef jit_internal_list<jit_block, jit_phi_incomming> ILIST_T; -public: - typedef std::list<jit_instruction *> instruction_list; - typedef instruction_list::iterator iterator; - typedef instruction_list::const_iterator const_iterator; - - typedef std::set<jit_block *> df_set; - typedef df_set::const_iterator df_iterator; - - static const size_t NO_ID = static_cast<size_t> (-1); - - jit_block (const std::string& aname, size_t avisit_count = 0) - : mvisit_count (avisit_count), mid (NO_ID), idom (0), mname (aname), - malive (false) - {} - - virtual void replace_with (jit_value *value); - - void replace_in_phi (jit_block *ablock, jit_block *with); - - // we have a new internal list, but we want to stay compatable with jit_value - jit_use *first_use (void) const { return jit_value::first_use (); } - - size_t use_count (void) const { return jit_value::use_count (); } - - // if a block is alive, then it might be visited during execution - bool alive (void) const { return malive; } - - void mark_alive (void) { malive = true; } - - // If we can merge with a successor, do so and return the now empty block - jit_block *maybe_merge (); - - // merge another block into this block, leaving the merge block empty - void merge (jit_block& merge); - - const std::string& name (void) const { return mname; } - - jit_instruction *prepend (jit_instruction *instr); - - jit_instruction *prepend_after_phi (jit_instruction *instr); - - template <typename T> - T *append (T *instr) - { - internal_append (instr); - return instr; - } - - jit_instruction *insert_before (iterator loc, jit_instruction *instr); - - jit_instruction *insert_before (jit_instruction *loc, jit_instruction *instr) - { - return insert_before (loc->location (), instr); - } - - jit_instruction *insert_after (iterator loc, jit_instruction *instr); - - jit_instruction *insert_after (jit_instruction *loc, jit_instruction *instr) - { - return insert_after (loc->location (), instr); - } - - iterator remove (iterator iter) - { - jit_instruction *instr = *iter; - iter = instructions.erase (iter); - instr->stash_parent (0, instructions.end ()); - return iter; - } - - jit_terminator *terminator (void) const; - - // is the jump from pred alive? - bool branch_alive (jit_block *asucc) const; - - jit_block *successor (size_t i) const; - - size_t successor_count (void) const; - - iterator begin (void) { return instructions.begin (); } - - const_iterator begin (void) const { return instructions.begin (); } - - iterator end (void) { return instructions.end (); } - - const_iterator end (void) const { return instructions.end (); } - - iterator phi_begin (void); - - iterator phi_end (void); - - iterator nonphi_begin (void); - - // must label before id is valid - size_t id (void) const { return mid; } - - // dominance frontier - const df_set& df (void) const { return mdf; } - - df_iterator df_begin (void) const { return mdf.begin (); } - - df_iterator df_end (void) const { return mdf.end (); } - - // label with a RPO walk - void label (void) - { - size_t number = 0; - label (mvisit_count, number); - } - - void label (size_t avisit_count, size_t& number) - { - if (visited (avisit_count)) - return; - - for (jit_use *use = first_use (); use; use = use->next ()) - { - jit_block *pred = use->user_parent (); - pred->label (avisit_count, number); - } - - mid = number++; - } - - // See for idom computation algorithm - // Cooper, Keith D.; Harvey, Timothy J; and Kennedy, Ken (2001). - // "A Simple, Fast Dominance Algorithm" - void compute_idom (jit_block *entry_block) - { - bool changed; - entry_block->idom = entry_block; - do - changed = update_idom (mvisit_count); - while (changed); - } - - // compute dominance frontier - void compute_df (void) - { - compute_df (mvisit_count); - } - - void create_dom_tree (void) - { - create_dom_tree (mvisit_count); - } - - jit_block *dom_successor (size_t idx) const - { - return dom_succ[idx]; - } - - size_t dom_successor_count (void) const - { - return dom_succ.size (); - } - - // call pop_varaible on all instructions - void pop_all (void); - - virtual std::ostream& print (std::ostream& os, size_t indent = 0) const - { - print_indent (os, indent); - short_print (os) << ": %pred = "; - for (jit_use *use = first_use (); use; use = use->next ()) - { - jit_block *pred = use->user_parent (); - os << *pred; - if (use->next ()) - os << ", "; - } - os << std::endl; - - for (const_iterator iter = begin (); iter != end (); ++iter) - { - jit_instruction *instr = *iter; - instr->print (os, indent + 1) << std::endl; - } - return os; - } - - // ... - jit_block *maybe_split (jit_convert& convert, jit_block *asuccessor); - - jit_block *maybe_split (jit_convert& convert, jit_block& asuccessor) - { - return maybe_split (convert, &asuccessor); - } - - // print dominator infomration - std::ostream& print_dom (std::ostream& os) const; - - virtual std::ostream& short_print (std::ostream& os) const - { - os << mname; - if (mid != NO_ID) - os << mid; - return os; - } - - llvm::BasicBlock *to_llvm (void) const; - - std::list<jit_block *>::iterator location (void) const - { return mlocation; } - - void stash_location (std::list<jit_block *>::iterator alocation) - { mlocation = alocation; } - - // used to prevent visiting the same node twice in the graph - size_t visit_count (void) const { return mvisit_count; } - - // check if this node has been visited yet at the given visit count. If we - // have not been visited yet, mark us as visited. - bool visited (size_t avisit_count) - { - if (mvisit_count <= avisit_count) - { - mvisit_count = avisit_count + 1; - return false; - } - - return true; - } - - JIT_VALUE_ACCEPT; -private: - void internal_append (jit_instruction *instr); - - void compute_df (size_t avisit_count); - - bool update_idom (size_t avisit_count); - - void create_dom_tree (size_t avisit_count); - - static jit_block *idom_intersect (jit_block *i, jit_block *j); - - size_t mvisit_count; - size_t mid; - jit_block *idom; - df_set mdf; - std::vector<jit_block *> dom_succ; - std::string mname; - instruction_list instructions; - bool malive; - std::list<jit_block *>::iterator mlocation; -}; - -// keeps track of phi functions that use a block on incomming edges -class -jit_phi_incomming : public jit_internal_node<jit_block, jit_phi_incomming> -{ -public: - jit_phi_incomming (void) : muser (0) {} - - jit_phi_incomming (jit_phi *auser) : muser (auser) {} - - jit_phi_incomming (const jit_phi_incomming& use) : jit_internal_node () - { - *this = use; - } - - jit_phi_incomming& operator= (const jit_phi_incomming& use) - { - stash_value (use.value ()); - muser = use.muser; - return *this; - } - - jit_phi *user (void) const { return muser; } - - jit_block *user_parent (void) const; -private: - jit_phi *muser; -}; - -// A non-ssa variable -class -jit_variable : public jit_value -{ -public: - jit_variable (const std::string& aname) : mname (aname), mlast_use (0) {} - - const std::string &name (void) const { return mname; } - - // manipulate the value_stack, for use during SSA construction. The top of the - // value stack represents the current value for this variable - bool has_top (void) const - { - return ! value_stack.empty (); - } - - jit_value *top (void) const - { - return value_stack.top (); - } - - void push (jit_instruction *v) - { - value_stack.push (v); - mlast_use = v; - } - - void pop (void) - { - value_stack.pop (); - } - - jit_instruction *last_use (void) const - { - return mlast_use; - } - - void stash_last_use (jit_instruction *instr) - { - mlast_use = instr; - } - - // blocks in which we are used - void use_blocks (jit_block::df_set& result) - { - jit_use *use = first_use (); - while (use) - { - result.insert (use->user_parent ()); - use = use->next (); - } - } - - virtual std::ostream& print (std::ostream& os, size_t indent = 0) const - { - return print_indent (os, indent) << mname; - } - - JIT_VALUE_ACCEPT; -private: - std::string mname; - std::stack<jit_value *> value_stack; - jit_instruction *mlast_use; -}; - -class -jit_assign_base : public jit_instruction -{ -public: - jit_assign_base (jit_variable *adest) : jit_instruction (), mdest (adest) {} - - jit_assign_base (jit_variable *adest, size_t npred) : jit_instruction (npred), - mdest (adest) {} - - jit_assign_base (jit_variable *adest, jit_value *arg0, jit_value *arg1) - : jit_instruction (arg0, arg1), mdest (adest) {} - - jit_variable *dest (void) const { return mdest; } - - virtual void push_variable (void) - { - mdest->push (this); - } - - virtual void pop_variable (void) - { - mdest->pop (); - } - - virtual std::ostream& short_print (std::ostream& os) const - { - if (type ()) - jit_print (os, type ()) << ": "; - - dest ()->short_print (os); - return os << "#" << id (); - } -private: - jit_variable *mdest; -}; - -class -jit_assign : public jit_assign_base -{ -public: - jit_assign (jit_variable *adest, jit_value *asrc) - : jit_assign_base (adest, adest, asrc), martificial (false) {} - - jit_value *overwrite (void) const - { - return argument (0); - } - - jit_value *src (void) const - { - return argument (1); - } - - // variables don't get modified in an SSA, but COW requires we modify - // variables. An artificial assign is for when a variable gets modified. We - // need an assign in the SSA, but the reference counts shouldn't be updated. - bool artificial (void) const { return martificial; } - - void mark_artificial (void) { martificial = true; } - - virtual bool infer (void) - { - jit_type *stype = src ()->type (); - if (stype != type()) - { - stash_type (stype); - return true; - } - - return false; - } - - virtual std::ostream& print (std::ostream& os, size_t indent = 0) const - { - print_indent (os, indent) << *this << " = " << *src (); - - if (artificial ()) - os << " [artificial]"; - - return os; - } - - JIT_VALUE_ACCEPT; -private: - bool martificial; -}; - -class -jit_phi : public jit_assign_base -{ -public: - jit_phi (jit_variable *adest, size_t npred) - : jit_assign_base (adest, npred) - { - mincomming.reserve (npred); - } - - // removes arguments form dead incomming jumps - bool prune (void); - - void add_incomming (jit_block *from, jit_value *value) - { - push_argument (value); - mincomming.push_back (jit_phi_incomming (this)); - mincomming[mincomming.size () - 1].stash_value (from); - } - - jit_block *incomming (size_t i) const - { - return mincomming[i].value (); - } - - llvm::BasicBlock *incomming_llvm (size_t i) const - { - return incomming (i)->to_llvm (); - } - - virtual void construct_ssa (void) {} - - virtual bool infer (void); - - virtual std::ostream& print (std::ostream& os, size_t indent = 0) const - { - std::stringstream ss; - print_indent (ss, indent); - short_print (ss) << " phi "; - std::string ss_str = ss.str (); - std::string indent_str (ss_str.size (), ' '); - os << ss_str; - - for (size_t i = 0; i < argument_count (); ++i) - { - if (i > 0) - os << indent_str; - os << "| "; - - os << *incomming (i) << " -> "; - os << *argument (i); - - if (i + 1 < argument_count ()) - os << std::endl; - } - - return os; - } - - llvm::PHINode *to_llvm (void) const; - - JIT_VALUE_ACCEPT; -private: - std::vector<jit_phi_incomming> mincomming; -}; - -class -jit_terminator : public jit_instruction -{ -public: -#define JIT_TERMINATOR_CONST(N) \ - jit_terminator (size_t asuccessor_count, \ - OCT_MAKE_DECL_LIST (jit_value *, arg, N)) \ - : jit_instruction (OCT_MAKE_ARG_LIST (arg, N)), \ - malive (asuccessor_count, false) {} - - JIT_TERMINATOR_CONST (1) - JIT_TERMINATOR_CONST (2) - JIT_TERMINATOR_CONST (3) - -#undef JIT_TERMINATOR_CONST - - jit_block *successor (size_t idx = 0) const - { - return static_cast<jit_block *> (argument (idx)); - } - - llvm::BasicBlock *successor_llvm (size_t idx = 0) const - { - return successor (idx)->to_llvm (); - } - - size_t successor_index (const jit_block *asuccessor) const; - - std::ostream& print_successor (std::ostream& os, size_t idx = 0) const - { - if (alive (idx)) - os << "[live] "; - else - os << "[dead] "; - - return successor (idx)->short_print (os); - } - - // Check if the jump to successor is live - bool alive (const jit_block *asuccessor) const - { - return alive (successor_index (asuccessor)); - } - - bool alive (size_t idx) const { return malive[idx]; } - - bool alive (int idx) const { return malive[idx]; } - - size_t successor_count (void) const { return malive.size (); } - - virtual bool infer (void); - - llvm::TerminatorInst *to_llvm (void) const; -protected: - virtual bool check_alive (size_t) const { return true; } -private: - std::vector<bool> malive; -}; - -class -jit_branch : public jit_terminator -{ -public: - jit_branch (jit_block *succ) : jit_terminator (1, succ) {} - - virtual size_t successor_count (void) const { return 1; } - - virtual std::ostream& print (std::ostream& os, size_t indent = 0) const - { - print_indent (os, indent) << "branch: "; - return print_successor (os); - } - - JIT_VALUE_ACCEPT; -}; - -class -jit_cond_branch : public jit_terminator -{ -public: - jit_cond_branch (jit_value *c, jit_block *ctrue, jit_block *cfalse) - : jit_terminator (2, ctrue, cfalse, c) {} - - jit_value *cond (void) const { return argument (2); } - - std::ostream& print_cond (std::ostream& os) const - { - return cond ()->short_print (os); - } - - llvm::Value *cond_llvm (void) const - { - return cond ()->to_llvm (); - } - - virtual size_t successor_count (void) const { return 2; } - - virtual std::ostream& print (std::ostream& os, size_t indent = 0) const - { - print_indent (os, indent) << "cond_branch: "; - print_cond (os) << ", "; - print_successor (os, 0) << ", "; - return print_successor (os, 1); - } - - JIT_VALUE_ACCEPT; -}; - -class -jit_call : public jit_instruction -{ -public: -#define JIT_CALL_CONST(N) \ - jit_call (const jit_operation& aoperation, \ - OCT_MAKE_DECL_LIST (jit_value *, arg, N)) \ - : jit_instruction (OCT_MAKE_ARG_LIST (arg, N)), moperation (aoperation) {} \ - \ - jit_call (const jit_operation& (*aoperation) (void), \ - OCT_MAKE_DECL_LIST (jit_value *, arg, N)) \ - : jit_instruction (OCT_MAKE_ARG_LIST (arg, N)), moperation (aoperation ()) \ - {} - - JIT_CALL_CONST (1) - JIT_CALL_CONST (2) - JIT_CALL_CONST (3) - JIT_CALL_CONST (4) - -#undef JIT_CALL_CONST - - - const jit_operation& operation (void) const { return moperation; } - - bool can_error (void) const - { - return overload ().can_error (); - } - - const jit_function& overload (void) const - { - return moperation.overload (argument_types ()); - } - - virtual bool needs_release (void) const - { - return type () && jit_typeinfo::get_release (type ()).valid (); - } - - virtual std::ostream& print (std::ostream& os, size_t indent = 0) const - { - print_indent (os, indent); - - if (use_count ()) - short_print (os) << " = "; - os << "call " << moperation.name () << " ("; - - for (size_t i = 0; i < argument_count (); ++i) - { - print_argument (os, i); - if (i + 1 < argument_count ()) - os << ", "; - } - return os << ")"; - } - - virtual bool infer (void); - - JIT_VALUE_ACCEPT; -private: - const jit_operation& moperation; -}; - -// FIXME: This is just ugly... -// checks error_state, if error_state is false then goto the normal branche, -// otherwise goto the error branch -class -jit_error_check : public jit_terminator -{ -public: - jit_error_check (jit_call *acheck_for, jit_block *normal, jit_block *error) - : jit_terminator (2, error, normal, acheck_for) {} - - jit_call *check_for (void) const - { - return static_cast<jit_call *> (argument (2)); - } - - virtual std::ostream& print (std::ostream& os, size_t indent = 0) const - { - print_indent (os, indent) << "error_check " << *check_for () << ", "; - print_successor (os, 1) << ", "; - return print_successor (os, 0); - } - - JIT_VALUE_ACCEPT; -protected: - virtual bool check_alive (size_t idx) const - { - return idx == 1 ? true : check_for ()->can_error (); - } -}; - -class -jit_extract_argument : public jit_assign_base -{ -public: - jit_extract_argument (jit_type *atype, jit_variable *adest) - : jit_assign_base (adest) - { - stash_type (atype); - } - - const std::string& name (void) const - { - return dest ()->name (); - } - - const jit_function& overload (void) const - { - return jit_typeinfo::cast (type (), jit_typeinfo::get_any ()); - } - - virtual std::ostream& print (std::ostream& os, size_t indent = 0) const - { - print_indent (os, indent); - - return short_print (os) << " = extract " << name (); - } - - JIT_VALUE_ACCEPT; -}; - -class -jit_store_argument : public jit_instruction -{ -public: - jit_store_argument (jit_variable *var) - : jit_instruction (var), dest (var) - {} - - const std::string& name (void) const - { - return dest->name (); - } - - const jit_function& overload (void) const - { - return jit_typeinfo::cast (jit_typeinfo::get_any (), result_type ()); - } - - jit_value *result (void) const - { - return argument (0); - } - - jit_type *result_type (void) const - { - return result ()->type (); - } - - llvm::Value *result_llvm (void) const - { - return result ()->to_llvm (); - } - - virtual std::ostream& print (std::ostream& os, size_t indent = 0) const - { - jit_value *res = result (); - print_indent (os, indent) << "store "; - dest->short_print (os); - - if (! isa<jit_variable> (res)) - { - os << " = "; - res->short_print (os); - } - - return os; - } - - JIT_VALUE_ACCEPT; -private: - jit_variable *dest; -}; - -class -jit_ir_walker -{ -public: - virtual ~jit_ir_walker () {} - -#define JIT_METH(clname) \ - virtual void visit (jit_ ## clname&) = 0; - - JIT_VISIT_IR_CLASSES; - -#undef JIT_METH -}; - -template <typename T, jit_type *(*EXTRACT_T)(void), typename PASS_T, bool QUOTE> -void -jit_const<T, EXTRACT_T, PASS_T, QUOTE>::accept (jit_ir_walker& walker) -{ - walker.visit (*this); -} - // convert between IRs // FIXME: Class relationships are messy from here on down. They need to be // cleaned up. @@ -2404,12 +446,5 @@ type_bound_vector bounds; }; -// some #defines we use in the header, but not the cc file -#undef JIT_VISIT_IR_CLASSES -#undef JIT_VISIT_IR_CONST -#undef JIT_VALUE_ACCEPT -#undef ASSIGN_ARG -#undef JIT_EXPAND - #endif #endif