Mercurial > hg > octave-jordi
changeset 9743:26abff55f6fe
optimize bsxfun for common built-in operations
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Tue, 20 Oct 2009 10:47:22 +0200 |
parents | 9f8ff01abc65 |
children | fb3543975ed9 |
files | liboctave/CNDArray.cc liboctave/CNDArray.h liboctave/ChangeLog liboctave/Makefile.in liboctave/boolNDArray.cc liboctave/boolNDArray.h liboctave/dNDArray.cc liboctave/dNDArray.h liboctave/dim-vector.h liboctave/fCNDArray.cc liboctave/fCNDArray.h liboctave/fNDArray.cc liboctave/fNDArray.h liboctave/int16NDArray.cc liboctave/int16NDArray.h liboctave/int32NDArray.cc liboctave/int32NDArray.h liboctave/int64NDArray.cc liboctave/int64NDArray.h liboctave/int8NDArray.cc liboctave/int8NDArray.h liboctave/mx-inlines.cc liboctave/oct-inttypes.h liboctave/uint16NDArray.cc liboctave/uint16NDArray.h liboctave/uint32NDArray.cc liboctave/uint32NDArray.h liboctave/uint64NDArray.cc liboctave/uint64NDArray.h liboctave/uint8NDArray.cc liboctave/uint8NDArray.h src/ChangeLog src/DLD-FUNCTIONS/bsxfun.cc src/ov-base.h |
diffstat | 34 files changed, 379 insertions(+), 39 deletions(-) [+] |
line wrap: on
line diff
--- a/liboctave/CNDArray.cc +++ b/liboctave/CNDArray.cc @@ -43,6 +43,8 @@ #include "oct-fftw.h" #include "oct-locbuf.h" +#include "bsxfun-defs.cc" + ComplexNDArray::ComplexNDArray (const charNDArray& a) : MArrayN<Complex> (a.dims ()) { @@ -1104,6 +1106,9 @@ return a; } +BSXFUN_STDOP_DEFS_MXLOOP (ComplexNDArray) +BSXFUN_STDREL_DEFS_MXLOOP (ComplexNDArray) + /* ;;; Local Variables: *** ;;; mode: C++ ***
--- a/liboctave/CNDArray.h +++ b/liboctave/CNDArray.h @@ -28,6 +28,7 @@ #include "mx-defs.h" #include "mx-op-decl.h" +#include "bsxfun-decl.h" class OCTAVE_API @@ -185,6 +186,9 @@ extern OCTAVE_API ComplexNDArray& operator *= (ComplexNDArray& a, double s); extern OCTAVE_API ComplexNDArray& operator /= (ComplexNDArray& a, double s); +BSXFUN_STDOP_DECLS (ComplexNDArray, OCTAVE_API) +BSXFUN_STDREL_DECLS (ComplexNDArray, OCTAVE_API) + #endif /*
--- a/liboctave/ChangeLog +++ b/liboctave/ChangeLog @@ -1,3 +1,44 @@ +2009-10-20 Jaroslav Hajek <highegg@gmail.com> + + * bsxfun-decl.h, bsxfun-defs.cc: New sources. + * Makefile.in: Add them. + * dim-vector.h (dim_vector::compute_index, + dim_vector::increment_index): Fix. + * mx-inlines.cc (DEFMXMAPPER, DEFMXLOCALMAPPER): New macros. + (mx_inline_xmin, mx_inline_xmax): New loops. + (mx_inline_fun): Remove. + * oct-inttypes.h (xmin (const octave_int<T>&, const octave_int<T>&), + xmax (const octave_int<T>&, const octave_int<T>&)): + New inline functions. + + * CNDArray.cc: Define bsxfun operations. + * boolNDArray.cc: Ditto. + * dNDArray.cc: Ditto. + * fCNDArray.cc: Ditto. + * fNDArray.cc: Ditto. + * int16NDArray.cc: Ditto. + * int32NDArray.cc: Ditto. + * int64NDArray.cc: Ditto. + * int8NDArray.cc: Ditto. + * uint16NDArray.cc: Ditto. + * uint32NDArray.cc: Ditto. + * uint64NDArray.cc: Ditto. + * uint8NDArray.cc: Ditto. + + * CNDArray.h: Declare bsxfun operations. + * boolNDArray.h: Ditto. + * dNDArray.h: Ditto. + * fCNDArray.h: Ditto. + * fNDArray.h: Ditto. + * int16NDArray.h: Ditto. + * int32NDArray.h: Ditto. + * int64NDArray.h: Ditto. + * int8NDArray.h: Ditto. + * uint16NDArray.h: Ditto. + * uint32NDArray.h: Ditto. + * uint64NDArray.h: Ditto. + * uint8NDArray.h: Ditto. + 2009-10-19 Jaroslav Hajek <highegg@gmail.com> * dim-vector.h (dim_vector::compute_index,
--- a/liboctave/Makefile.in +++ b/liboctave/Makefile.in @@ -50,7 +50,8 @@ MATRIX_INC := Array.h Array2.h Array3.h ArrayN.h DiagArray2.h \ Array-util.h MArray-decl.h MArray-defs.h \ MArray.h MArray2.h MDiagArray2.h Matrix.h MArrayN.h \ - base-lu.h base-qr.h base-aepbal.h dim-vector.h mx-base.h mx-op-decl.h \ + base-lu.h base-qr.h base-aepbal.h bsxfun-decl.h dim-vector.h \ + mx-base.h mx-op-decl.h \ mx-op-defs.h mx-defs.h mx-ext.h CColVector.h CDiagMatrix.h \ CMatrix.h CNDArray.h CRowVector.h CmplxAEPBAL.h CmplxCHOL.h \ CmplxGEPBAL.h CmplxHESS.h CmplxLU.h CmplxQR.h CmplxQRP.h \ @@ -110,7 +111,7 @@ $(VX_OP_INC) \ $(SPARSE_MX_OP_INC) -TEMPLATE_SRC := Array.cc eigs-base.cc DiagArray2.cc \ +TEMPLATE_SRC := Array.cc bsxfun-defs.cc eigs-base.cc DiagArray2.cc \ MArray.cc MArray2.cc MArrayN.cc MDiagArray2.cc \ base-lu.cc base-qr.cc oct-sort.cc sparse-base-lu.cc \ sparse-base-chol.cc sparse-dmsolve.cc
--- a/liboctave/boolNDArray.cc +++ b/liboctave/boolNDArray.cc @@ -34,6 +34,8 @@ #include "mx-op-defs.h" #include "MArray-defs.h" +#include "bsxfun-defs.cc" + // unary operations boolNDArray @@ -184,6 +186,9 @@ return a; } +BSXFUN_OP_DEF_MXLOOP (and, boolNDArray, mx_inline_and) +BSXFUN_OP_DEF_MXLOOP (or, boolNDArray, mx_inline_or) + /* ;;; Local Variables: *** ;;; mode: C++ ***
--- a/liboctave/boolNDArray.h +++ b/liboctave/boolNDArray.h @@ -27,9 +27,11 @@ #include "mx-defs.h" #include "mx-op-decl.h" +#include "bsxfun-decl.h" #include "boolMatrix.h" + class OCTAVE_API boolNDArray : public Array<bool> @@ -137,6 +139,9 @@ extern OCTAVE_API boolNDArray& mx_el_or_assign (boolNDArray& m, const boolNDArray& a); +BSXFUN_OP_DECL (and, boolNDArray, OCTAVE_API); +BSXFUN_OP_DECL (or, boolNDArray, OCTAVE_API); + #endif /*
--- a/liboctave/dNDArray.cc +++ b/liboctave/dNDArray.cc @@ -43,6 +43,8 @@ #include "oct-fftw.h" #include "oct-locbuf.h" +#include "bsxfun-defs.cc" + NDArray::NDArray (const Array<octave_idx_type>& a, bool zero_based, bool negative_to_nan) { @@ -1132,6 +1134,9 @@ NDND_CMP_OPS (NDArray, NDArray) NDND_BOOL_OPS (NDArray, NDArray) +BSXFUN_STDOP_DEFS_MXLOOP (NDArray) +BSXFUN_STDREL_DEFS_MXLOOP (NDArray) + /* ;;; Local Variables: *** ;;; mode: C++ ***
--- a/liboctave/dNDArray.h +++ b/liboctave/dNDArray.h @@ -30,6 +30,7 @@ #include "mx-defs.h" #include "mx-op-decl.h" +#include "bsxfun-decl.h" class OCTAVE_API @@ -195,6 +196,9 @@ MARRAY_FORWARD_DEFS (MArrayN, NDArray, double) +BSXFUN_STDOP_DECLS (NDArray, OCTAVE_API) +BSXFUN_STDREL_DECLS (NDArray, OCTAVE_API) + #endif /*
--- a/liboctave/dim-vector.h +++ b/liboctave/dim-vector.h @@ -520,7 +520,7 @@ octave_idx_type compute_index (const octave_idx_type *idx) { octave_idx_type k = 0; - for (int i = length () - 1; i >= 0; i++) + for (int i = length () - 1; i >= 0; i--) k = k * rep[i] + idx[i]; return k; @@ -535,7 +535,7 @@ for (i = start; i < length (); i++) { if (++(*idx) == rep[i]) - *idx = 0; + *idx++ = 0; else break; }
--- a/liboctave/fCNDArray.cc +++ b/liboctave/fCNDArray.cc @@ -43,6 +43,8 @@ #include "oct-fftw.h" #include "oct-locbuf.h" +#include "bsxfun-defs.cc" + FloatComplexNDArray::FloatComplexNDArray (const charNDArray& a) : MArrayN<FloatComplex> (a.dims ()) { @@ -1099,6 +1101,9 @@ return a; } +BSXFUN_STDOP_DEFS_MXLOOP (FloatComplexNDArray) +BSXFUN_STDREL_DEFS_MXLOOP (FloatComplexNDArray) + /* ;;; Local Variables: *** ;;; mode: C++ ***
--- a/liboctave/fCNDArray.h +++ b/liboctave/fCNDArray.h @@ -28,6 +28,7 @@ #include "mx-defs.h" #include "mx-op-decl.h" +#include "bsxfun-decl.h" class OCTAVE_API @@ -185,6 +186,9 @@ extern OCTAVE_API FloatComplexNDArray& operator *= (FloatComplexNDArray& a, float s); extern OCTAVE_API FloatComplexNDArray& operator /= (FloatComplexNDArray& a, float s); +BSXFUN_STDOP_DECLS (FloatComplexNDArray, OCTAVE_API) +BSXFUN_STDREL_DECLS (FloatComplexNDArray, OCTAVE_API) + #endif /*
--- a/liboctave/fNDArray.cc +++ b/liboctave/fNDArray.cc @@ -43,6 +43,8 @@ #include "oct-fftw.h" #include "oct-locbuf.h" +#include "bsxfun-defs.cc" + FloatNDArray::FloatNDArray (const charNDArray& a) : MArrayN<float> (a.dims ()) { @@ -1090,6 +1092,9 @@ NDND_CMP_OPS (FloatNDArray, FloatNDArray) NDND_BOOL_OPS (FloatNDArray, FloatNDArray) +BSXFUN_STDOP_DEFS_MXLOOP (FloatNDArray) +BSXFUN_STDREL_DEFS_MXLOOP (FloatNDArray) + /* ;;; Local Variables: *** ;;; mode: C++ ***
--- a/liboctave/fNDArray.h +++ b/liboctave/fNDArray.h @@ -30,6 +30,7 @@ #include "mx-defs.h" #include "mx-op-decl.h" +#include "bsxfun-decl.h" class OCTAVE_API @@ -192,6 +193,9 @@ MARRAY_FORWARD_DEFS (MArrayN, FloatNDArray, float) +BSXFUN_STDOP_DECLS (FloatNDArray, OCTAVE_API) +BSXFUN_STDREL_DECLS (FloatNDArray, OCTAVE_API) + #endif /*
--- a/liboctave/int16NDArray.cc +++ b/liboctave/int16NDArray.cc @@ -29,6 +29,8 @@ #include "mx-op-defs.h" #include "intNDArray.cc" +#include "bsxfun-defs.cc" + template class OCTAVE_API intNDArray<octave_int16>; template OCTAVE_API @@ -50,6 +52,9 @@ MINMAX_FCNS (int16) +BSXFUN_STDOP_DEFS_MXLOOP (int16NDArray) +BSXFUN_STDREL_DEFS_MXLOOP (int16NDArray) + /* ;;; Local Variables: *** ;;; mode: C++ ***
--- a/liboctave/int16NDArray.h +++ b/liboctave/int16NDArray.h @@ -26,6 +26,7 @@ #include "intNDArray.h" #include "mx-op-decl.h" #include "oct-inttypes.h" +#include "bsxfun-decl.h" typedef intNDArray<octave_int16> int16NDArray; @@ -42,6 +43,9 @@ MINMAX_DECLS (int16) +BSXFUN_STDOP_DECLS (int16NDArray, OCTAVE_API) +BSXFUN_STDREL_DECLS (int16NDArray, OCTAVE_API) + #endif /*
--- a/liboctave/int32NDArray.cc +++ b/liboctave/int32NDArray.cc @@ -29,6 +29,8 @@ #include "mx-op-defs.h" #include "intNDArray.cc" +#include "bsxfun-defs.cc" + template class OCTAVE_API intNDArray<octave_int32>; template OCTAVE_API @@ -50,6 +52,9 @@ MINMAX_FCNS (int32) +BSXFUN_STDOP_DEFS_MXLOOP (int32NDArray) +BSXFUN_STDREL_DEFS_MXLOOP (int32NDArray) + /* ;;; Local Variables: *** ;;; mode: C++ ***
--- a/liboctave/int32NDArray.h +++ b/liboctave/int32NDArray.h @@ -26,6 +26,7 @@ #include "intNDArray.h" #include "mx-op-decl.h" #include "oct-inttypes.h" +#include "bsxfun-decl.h" typedef intNDArray<octave_int32> int32NDArray; @@ -42,6 +43,9 @@ MINMAX_DECLS (int32) +BSXFUN_STDOP_DECLS (int32NDArray, OCTAVE_API) +BSXFUN_STDREL_DECLS (int32NDArray, OCTAVE_API) + #endif /*
--- a/liboctave/int64NDArray.cc +++ b/liboctave/int64NDArray.cc @@ -29,6 +29,8 @@ #include "mx-op-defs.h" #include "intNDArray.cc" +#include "bsxfun-defs.cc" + template class OCTAVE_API intNDArray<octave_int64>; template OCTAVE_API @@ -50,6 +52,9 @@ MINMAX_FCNS (int64) +BSXFUN_STDOP_DEFS_MXLOOP (int64NDArray) +BSXFUN_STDREL_DEFS_MXLOOP (int64NDArray) + /* ;;; Local Variables: *** ;;; mode: C++ ***
--- a/liboctave/int64NDArray.h +++ b/liboctave/int64NDArray.h @@ -26,6 +26,7 @@ #include "intNDArray.h" #include "mx-op-decl.h" #include "oct-inttypes.h" +#include "bsxfun-decl.h" typedef intNDArray<octave_int64> int64NDArray; @@ -42,6 +43,9 @@ MINMAX_DECLS (int64) +BSXFUN_STDOP_DECLS (int64NDArray, OCTAVE_API) +BSXFUN_STDREL_DECLS (int64NDArray, OCTAVE_API) + #endif /*
--- a/liboctave/int8NDArray.cc +++ b/liboctave/int8NDArray.cc @@ -29,6 +29,8 @@ #include "mx-op-defs.h" #include "intNDArray.cc" +#include "bsxfun-defs.cc" + template class OCTAVE_API intNDArray<octave_int8>; template OCTAVE_API @@ -50,6 +52,9 @@ MINMAX_FCNS (int8) +BSXFUN_STDOP_DEFS_MXLOOP (int8NDArray) +BSXFUN_STDREL_DEFS_MXLOOP (int8NDArray) + /* ;;; Local Variables: *** ;;; mode: C++ ***
--- a/liboctave/int8NDArray.h +++ b/liboctave/int8NDArray.h @@ -26,6 +26,7 @@ #include "intNDArray.h" #include "mx-op-decl.h" #include "oct-inttypes.h" +#include "bsxfun-decl.h" typedef intNDArray<octave_int8> int8NDArray; @@ -42,6 +43,9 @@ MINMAX_DECLS (int8) +BSXFUN_STDOP_DECLS (int8NDArray, OCTAVE_API) +BSXFUN_STDREL_DECLS (int8NDArray, OCTAVE_API) + #endif /*
--- a/liboctave/mx-inlines.cc +++ b/liboctave/mx-inlines.cc @@ -197,39 +197,28 @@ DEFMXANYNAN(Complex) DEFMXANYNAN(FloatComplex) -// Arbitrary unary/binary function mappers. Note the function reference is a -// template parameter! -template <class R, class X, R F(X)> -void mx_inline_fun (size_t n, R *r, const X *x) -{ for (size_t i = 0; i < n; i++) r[i] = F(x[i]); } - -template <class R, class X, R F(const X&)> -void mx_inline_fun (size_t n, R *r, const X *x) -{ for (size_t i = 0; i < n; i++) r[i] = F(x[i]); } - -template <class R, class X, class Y, R F(X, Y)> -void mx_inline_fun (size_t n, R *r, const X *x, const Y *y) -{ for (size_t i = 0; i < n; i++) r[i] = F(x[i], y[i]); } +// Pairwise minimums/maximums +#define DEFMXMAPPER(F, FUN) \ +template <class T> \ +inline void F (size_t n, T *r, const T *x, const T *y) \ +{ for (size_t i = 0; i < n; i++) r[i] = FUN (x[i], y[i]); } \ +template <class T> \ +inline void F (size_t n, T *r, const T *x, T y) \ +{ for (size_t i = 0; i < n; i++) r[i] = FUN (x[i], y); } \ +template <class T> \ +inline void F (size_t n, T *r, T x, const T *y) \ +{ for (size_t i = 0; i < n; i++) r[i] = FUN (x, y[i]); } -template <class R, class X, class Y, R F(X, Y)> -void mx_inline_fun (size_t n, R *r, X x, const Y *y) -{ for (size_t i = 0; i < n; i++) r[i] = F(x, y[i]); } - -template <class R, class X, class Y, R F(X, Y)> -void mx_inline_fun (size_t n, R *r, const X *x, Y y) -{ for (size_t i = 0; i < n; i++) r[i] = F(x[i], y); } +DEFMXMAPPER (mx_inline_xmin, xmin) +DEFMXMAPPER (mx_inline_xmax, xmax) -template <class R, class X, class Y, R F(const X&, const Y&)> -void mx_inline_fun (size_t n, R *r, const X *x, const Y *y) -{ for (size_t i = 0; i < n; i++) r[i] = F(x[i], y[i]); } - -template <class R, class X, class Y, R F(const X&, const Y&)> -void mx_inline_fun (size_t n, R *r, X x, const Y *y) -{ for (size_t i = 0; i < n; i++) r[i] = F(x, y[i]); } - -template <class R, class X, class Y, R F(const X&, const Y&)> -void mx_inline_fun (size_t n, R *r, const X *x, Y y) -{ for (size_t i = 0; i < n; i++) r[i] = F(x[i], y); } +#define DEFMXLOCALMAPPER(F, FUN, T) \ +static void F (size_t n, T *r, const T *x, const T *y) \ +{ for (size_t i = 0; i < n; i++) r[i] = FUN (x[i], y[i]); } \ +static void F (size_t n, T *r, const T *x, T y) \ +{ for (size_t i = 0; i < n; i++) r[i] = FUN (x[i], y); } \ +static void F (size_t n, T *r, T x, const T *y) \ +{ for (size_t i = 0; i < n; i++) r[i] = FUN (x, y[i]); } // Appliers. Since these call the operation just once, we pass it as // a pointer, to allow the compiler reduce number of instances.
--- a/liboctave/oct-inttypes.h +++ b/liboctave/oct-inttypes.h @@ -1090,6 +1090,22 @@ #undef OCTAVE_INT_FLOAT_CMP_OP +template <class T> +octave_int<T> +xmax (const octave_int<T>& x, const octave_int<T>& y) +{ + const T xv = x.value (), yv = y.value (); + return octave_int<T> (xv >= yv ? xv : yv); +} + +template <class T> +octave_int<T> +xmin (const octave_int<T>& x, const octave_int<T>& y) +{ + const T xv = x.value (), yv = y.value (); + return octave_int<T> (xv <= yv ? xv : yv); +} + #endif /*
--- a/liboctave/uint16NDArray.cc +++ b/liboctave/uint16NDArray.cc @@ -29,6 +29,8 @@ #include "mx-op-defs.h" #include "intNDArray.cc" +#include "bsxfun-defs.cc" + template class OCTAVE_API intNDArray<octave_uint16>; template OCTAVE_API @@ -50,6 +52,9 @@ MINMAX_FCNS (uint16) +BSXFUN_STDOP_DEFS_MXLOOP (uint16NDArray) +BSXFUN_STDREL_DEFS_MXLOOP (uint16NDArray) + /* ;;; Local Variables: *** ;;; mode: C++ ***
--- a/liboctave/uint16NDArray.h +++ b/liboctave/uint16NDArray.h @@ -26,6 +26,7 @@ #include "intNDArray.h" #include "mx-op-decl.h" #include "oct-inttypes.h" +#include "bsxfun-decl.h" typedef intNDArray<octave_uint16> uint16NDArray; @@ -42,6 +43,9 @@ MINMAX_DECLS (uint16) +BSXFUN_STDOP_DECLS (uint16NDArray, OCTAVE_API) +BSXFUN_STDREL_DECLS (uint16NDArray, OCTAVE_API) + #endif /*
--- a/liboctave/uint32NDArray.cc +++ b/liboctave/uint32NDArray.cc @@ -29,6 +29,8 @@ #include "mx-op-defs.h" #include "intNDArray.cc" +#include "bsxfun-defs.cc" + template class OCTAVE_API intNDArray<octave_uint32>; template OCTAVE_API @@ -50,6 +52,9 @@ MINMAX_FCNS (uint32) +BSXFUN_STDOP_DEFS_MXLOOP (uint32NDArray) +BSXFUN_STDREL_DEFS_MXLOOP (uint32NDArray) + /* ;;; Local Variables: *** ;;; mode: C++ ***
--- a/liboctave/uint32NDArray.h +++ b/liboctave/uint32NDArray.h @@ -26,6 +26,7 @@ #include "intNDArray.h" #include "mx-op-decl.h" #include "oct-inttypes.h" +#include "bsxfun-decl.h" typedef intNDArray<octave_uint32> uint32NDArray; @@ -42,6 +43,9 @@ MINMAX_DECLS (uint32) +BSXFUN_STDOP_DECLS (uint32NDArray, OCTAVE_API) +BSXFUN_STDREL_DECLS (uint32NDArray, OCTAVE_API) + #endif /*
--- a/liboctave/uint64NDArray.cc +++ b/liboctave/uint64NDArray.cc @@ -29,6 +29,8 @@ #include "mx-op-defs.h" #include "intNDArray.cc" +#include "bsxfun-defs.cc" + template class OCTAVE_API intNDArray<octave_uint64>; template OCTAVE_API @@ -50,6 +52,9 @@ MINMAX_FCNS (uint64) +BSXFUN_STDOP_DEFS_MXLOOP (uint64NDArray) +BSXFUN_STDREL_DEFS_MXLOOP (uint64NDArray) + /* ;;; Local Variables: *** ;;; mode: C++ ***
--- a/liboctave/uint64NDArray.h +++ b/liboctave/uint64NDArray.h @@ -26,6 +26,7 @@ #include "intNDArray.h" #include "mx-op-decl.h" #include "oct-inttypes.h" +#include "bsxfun-decl.h" typedef intNDArray<octave_uint64> uint64NDArray; @@ -42,6 +43,9 @@ MINMAX_DECLS (uint64) +BSXFUN_STDOP_DECLS (uint64NDArray, OCTAVE_API) +BSXFUN_STDREL_DECLS (uint64NDArray, OCTAVE_API) + #endif /*
--- a/liboctave/uint8NDArray.cc +++ b/liboctave/uint8NDArray.cc @@ -29,6 +29,8 @@ #include "mx-op-defs.h" #include "intNDArray.cc" +#include "bsxfun-defs.cc" + template class OCTAVE_API intNDArray<octave_uint8>; template OCTAVE_API @@ -50,6 +52,9 @@ MINMAX_FCNS (uint8) +BSXFUN_STDOP_DEFS_MXLOOP (uint8NDArray) +BSXFUN_STDREL_DEFS_MXLOOP (uint8NDArray) + /* ;;; Local Variables: *** ;;; mode: C++ ***
--- a/liboctave/uint8NDArray.h +++ b/liboctave/uint8NDArray.h @@ -26,6 +26,7 @@ #include "intNDArray.h" #include "mx-op-decl.h" #include "oct-inttypes.h" +#include "bsxfun-decl.h" typedef intNDArray<octave_uint8> uint8NDArray; @@ -42,6 +43,9 @@ MINMAX_DECLS (uint8) +BSXFUN_STDOP_DECLS (uint8NDArray, OCTAVE_API) +BSXFUN_STDREL_DECLS (uint8NDArray, OCTAVE_API) + #endif /*
--- a/src/ChangeLog +++ b/src/ChangeLog @@ -1,3 +1,14 @@ +2009-10-20 Jaroslav Hajek <highegg@gmail.com> + + * ov-base.h (builtin_type_t): Declare also btyp_num_types. + * DLD-FUNCTIONS/bsxfun.cc (bsxfun_builtin_op): New enum. + (bsxfun_handler): New typedef. + (bsxfun_builtin_names, bsxfun_handler_table): New variables. + (bsxfun_builtin_lookup, maybe_fill_table, maybe_optimized_builtin): + New static funcs. + (bsxfun_forward_op, bsxfun_forward_rel): New static template funcs. + (Fbsxfun): Try to optimize some built-in operations. + 2009-10-19 Jaroslav Hajek <highegg@gmail.com> * DLD-FUNCTIONS/cellfun.cc (Fcellslices): Allow non-positive indices
--- a/src/DLD-FUNCTIONS/bsxfun.cc +++ b/src/DLD-FUNCTIONS/bsxfun.cc @@ -36,6 +36,154 @@ #include "variables.h" #include "ov-colon.h" #include "unwind-prot.h" +#include "ov-fcn-handle.h" + +// Optimized bsxfun operations +enum bsxfun_builtin_op +{ + bsxfun_builtin_plus = 0, + bsxfun_builtin_minus, + bsxfun_builtin_times, + bsxfun_builtin_divide, + bsxfun_builtin_max, + bsxfun_builtin_min, + bsxfun_builtin_eq, + bsxfun_builtin_ne, + bsxfun_builtin_lt, + bsxfun_builtin_le, + bsxfun_builtin_gt, + bsxfun_builtin_ge, + bsxfun_builtin_and, + bsxfun_builtin_or, + bsxfun_builtin_unknown, + bsxfun_num_builtin_ops = bsxfun_builtin_unknown +}; + +const char *bsxfun_builtin_names[] = +{ + "plus", + "minus", + "times", + "rdivide", + "max", + "min", + "eq", + "ne", + "lt", + "le", + "gt", + "ge", + "and", + "or" +}; + +static bsxfun_builtin_op +bsxfun_builtin_lookup (const std::string& name) +{ + for (int i = 0; i < bsxfun_num_builtin_ops; i++) + if (name == bsxfun_builtin_names[i]) + return static_cast<bsxfun_builtin_op> (i); + return bsxfun_builtin_unknown; +} + +typedef octave_value (*bsxfun_handler) (const octave_value&, const octave_value&); + +// Static table of handlers. +bsxfun_handler bsxfun_handler_table[bsxfun_num_builtin_ops][btyp_num_types]; + +template <class NDA, NDA (bsxfun_op) (const NDA&, const NDA&)> +static octave_value +bsxfun_forward_op (const octave_value& x, const octave_value& y) +{ + NDA xa = octave_value_extract<NDA> (x); + NDA ya = octave_value_extract<NDA> (y); + return octave_value (bsxfun_op (xa, ya)); +} + +template <class NDA, boolNDArray (bsxfun_rel) (const NDA&, const NDA&)> +static octave_value +bsxfun_forward_rel (const octave_value& x, const octave_value& y) +{ + NDA xa = octave_value_extract<NDA> (x); + NDA ya = octave_value_extract<NDA> (y); + return octave_value (bsxfun_rel (xa, ya)); +} + +static void maybe_fill_table (void) +{ + static bool filled = false; + if (filled) + return; + +#define REGISTER_OP_HANDLER(OP, BTYP, NDA, FUNOP) \ + bsxfun_handler_table[OP][BTYP] = bsxfun_forward_op<NDA, FUNOP> +#define REGISTER_REL_HANDLER(REL, BTYP, NDA, FUNREL) \ + bsxfun_handler_table[REL][BTYP] = bsxfun_forward_rel<NDA, FUNREL> +#define REGISTER_STD_HANDLERS(BTYP, NDA) \ + REGISTER_OP_HANDLER (bsxfun_builtin_plus, BTYP, NDA, bsxfun_add); \ + REGISTER_OP_HANDLER (bsxfun_builtin_minus, BTYP, NDA, bsxfun_sub); \ + REGISTER_OP_HANDLER (bsxfun_builtin_times, BTYP, NDA, bsxfun_mul); \ + REGISTER_OP_HANDLER (bsxfun_builtin_divide, BTYP, NDA, bsxfun_div); \ + REGISTER_OP_HANDLER (bsxfun_builtin_max, BTYP, NDA, bsxfun_max); \ + REGISTER_OP_HANDLER (bsxfun_builtin_min, BTYP, NDA, bsxfun_min); \ + REGISTER_REL_HANDLER (bsxfun_builtin_eq, BTYP, NDA, bsxfun_eq); \ + REGISTER_REL_HANDLER (bsxfun_builtin_ne, BTYP, NDA, bsxfun_ne); \ + REGISTER_REL_HANDLER (bsxfun_builtin_lt, BTYP, NDA, bsxfun_lt); \ + REGISTER_REL_HANDLER (bsxfun_builtin_le, BTYP, NDA, bsxfun_le); \ + REGISTER_REL_HANDLER (bsxfun_builtin_gt, BTYP, NDA, bsxfun_gt); \ + REGISTER_REL_HANDLER (bsxfun_builtin_ge, BTYP, NDA, bsxfun_ge) + + REGISTER_STD_HANDLERS (btyp_double, NDArray); + REGISTER_STD_HANDLERS (btyp_float, FloatNDArray); + REGISTER_STD_HANDLERS (btyp_complex, ComplexNDArray); + REGISTER_STD_HANDLERS (btyp_float_complex, FloatComplexNDArray); + REGISTER_STD_HANDLERS (btyp_int8, int8NDArray); + REGISTER_STD_HANDLERS (btyp_int16, int16NDArray); + REGISTER_STD_HANDLERS (btyp_int32, int32NDArray); + REGISTER_STD_HANDLERS (btyp_int64, int64NDArray); + REGISTER_STD_HANDLERS (btyp_uint8, uint8NDArray); + REGISTER_STD_HANDLERS (btyp_uint16, uint16NDArray); + REGISTER_STD_HANDLERS (btyp_uint32, uint32NDArray); + REGISTER_STD_HANDLERS (btyp_uint64, uint64NDArray); + + // For bools, we register and/or. + REGISTER_OP_HANDLER (bsxfun_builtin_and, btyp_bool, boolNDArray, bsxfun_and); + REGISTER_OP_HANDLER (bsxfun_builtin_or, btyp_bool, boolNDArray, bsxfun_or); +} + +static octave_value +maybe_optimized_builtin (const std::string& name, + const octave_value& a, const octave_value& b) +{ + octave_value retval; + + maybe_fill_table (); + + bsxfun_builtin_op op = bsxfun_builtin_lookup (name); + if (op != bsxfun_builtin_unknown) + { + builtin_type_t btyp_a = a.builtin_type (), btyp_b = b.builtin_type (); + + // Simplify single/double combinations. + if (btyp_a == btyp_float && btyp_b == btyp_double) + btyp_b = btyp_float; + else if (btyp_a == btyp_double && btyp_b == btyp_float) + btyp_a = btyp_float; + else if (btyp_a == btyp_float_complex && btyp_b == btyp_complex) + btyp_b = btyp_float_complex; + else if (btyp_a == btyp_complex && btyp_b == btyp_float_complex) + btyp_a = btyp_float_complex; + + if (btyp_a == btyp_b && btyp_a != btyp_unknown) + { + bsxfun_handler handler = bsxfun_handler_table[op][btyp_a]; + if (handler) + retval = handler (a, b); + } + } + + return retval; +} static bool maybe_update_column (octave_value& Ac, const octave_value& A, @@ -160,12 +308,27 @@ else if (! (args(0).is_function_handle () || args(0).is_inline_function ())) error ("bsxfun: first argument must be a string or function handle"); - if (! error_state) + const octave_value A = args (1); + const octave_value B = args (2); + + if (func.is_builtin_function () + || (func.is_function_handle () + && ! func.fcn_handle_value ()->is_overloaded () + && ! A.is_object () && ! B.is_object ())) + { + octave_function *fcn_val = func.function_value (); + if (fcn_val) + { + octave_value tmp = maybe_optimized_builtin (fcn_val->name (), A, B); + if (tmp.is_defined ()) + retval(0) = tmp; + } + } + + if (! error_state && retval.empty ()) { - const octave_value A = args (1); dim_vector dva = A.dims (); octave_idx_type nda = dva.length (); - const octave_value B = args (2); dim_vector dvb = B.dims (); octave_idx_type ndb = dvb.length (); octave_idx_type nd = nda;