Expression templates
=
Templates that
represent expressions
// Scalar case: `+` groups left to right, so b + c + d is (b + c) + d.
float a, b, c, d;
a = b + c + d;
float a, b, c, d;
a = b + c + d;
a = (b + c) + d;
// The same expression written with a user-defined vector type.
Vec a, b, c, d;
a = b + c + d;
a = (b + c) + d;
//  ^^^^^^^
//  temporary
// Vec sum with the implicit left-to-right grouping written out.
Vec a, b, c, d;
a = b + c + d;
a = (b + c) + d;
//  ^^^^^^^
//  temporary
// Step 1: tmp1 receives b + c, one component at a time.
tmp1.x = b.x + c.x;
tmp1.y = b.y + c.y;
tmp1.z = b.z + c.z;
// Step 2: tmp2 receives tmp1 + d.
tmp2.x = tmp1.x + d.x;
tmp2.y = tmp1.y + d.y;
tmp2.z = tmp1.z + d.z;
// Step 3: a is copied from tmp2.
a.x = tmp2.x;
a.y = tmp2.y;
a.z = tmp2.z;
// Naive 3-component vector. Each operator+ builds and returns a new Vec,
// so a chained sum produces one intermediate Vec per `+`.
// (The data members x, y, z and the three-argument constructor used below
// are not shown in this snippet.)
class Vec {
public:
    // Component-wise sum of *this and `other`, returned by value.
    Vec
    operator+ (const Vec& other) const
    {
        return Vec (
            x+other.x,
            y+other.y,
            z+other.z);
    }

    // Copies the three components of `other` into this vector.
    void
    operator= (const Vec& other)
    {
        x = other.x;
        y = other.y;
        z = other.z;
    }
};
// Faster way
// Each component of a is computed in one expression directly from
// b, c and d; no tmp1/tmp2 intermediates are written.
a.x = b.x + c.x + d.x;
a.y = b.y + c.y + d.y;
a.z = b.z + c.z + d.z;
// The Vec sum once more, with the implicit grouping made explicit.
Vec a, b, c, d;
a = b + c + d;
a = (b + c) + d;
//  ^^^^^^^
//  temporary
// Rewrite each `+` as a BinarySum node, innermost first.
a = ((b + c) + d);
// The inner sum (b + c) becomes a node.
a = (BinarySum(b, c) + d);
// The outer sum becomes a node whose left operand is the inner node.
a = BinarySum(
BinarySum(b, c),
d
);
// Expression node standing for `lhs + rhs`.
// (The lhs/rhs operand members and the constructor are not shown in this
// snippet.)
class BinarySum {
public:
    // Returns component i of the sum by asking each operand for its own
    // component i. Declared const so it can be called through the
    // `const BinarySum&` that Vec::operator= receives.
    float eval(size_t i) const {
        return lhs.eval(i) + rhs.eval(i);
    }
};
// Assignment from a sum expression: each of the three stored components is
// filled with the matching component reported by the expression's eval().
void Vec::operator= (const BinarySum& other) {
    const size_t component_count = 3;
    for (size_t k = 0; k != component_count; ++k) {
        data_[k] = other.eval(k);
    }
}
// Build 1: the nested BinarySum nodes, then the assignment that consumes them.
auto tmp1 = BinarySum(b, c);
auto tmp2 = BinarySum(tmp1, d);
a = tmp2;
// Build 2: the assignment from a BinarySum is a loop over the 3 components.
auto tmp1 = BinarySum(b, c);
auto tmp2 = BinarySum(tmp1, d);
a = tmp2;
for (size_t i = 0; i < 3; ++i) {
a[i] = tmp2.eval(i);
}
// Build 3: the second line in the loop is a rewrite of the first
// (tmp2 = BinarySum(tmp1, d), so tmp2.eval(i) is tmp1's component plus d's),
// not a second store.
auto tmp1 = BinarySum(b, c);
auto tmp2 = BinarySum(tmp1, d);
a = tmp2;
for (size_t i = 0; i < 3; ++i) {
a[i] = tmp2.eval(i);
a[i] = tmp1.eval(i) + d[i];
}
// Build 4: tmp1.eval(i) is expanded as well (tmp1 = BinarySum(b, c)); the
// last line is the fully expanded form of the same single statement.
auto tmp1 = BinarySum(b, c);
auto tmp2 = BinarySum(tmp1, d);
a = tmp2;
for (size_t i = 0; i < 3; ++i) {
a[i] = tmp2.eval(i);
a[i] = tmp1.eval(i) + d[i];
a[i] = b[i] + c[i] + d[i];
}
// Expression node templates, each parameterised on the operation (Op) and
// on the type(s) of its operand(s).
UnaryOp<Op, Type>;
// sqrt, exp, neg, ...
BinaryOp<Op, TypeL, TypeR>;
// dot product, /, ...
CwiseBinaryOp<
Op, TypeL, TypeR
>;
// +, -, cwise prod, ...
// (same list, repeated)
UnaryOp<Op, Type>;
// sqrt, exp, neg, ...
BinaryOp<Op, TypeL, TypeR>;
// dot product, /, ...
CwiseBinaryOp<
Op, TypeL, TypeR
>;
// +, -, cwise prod, ...
// The statement traced through the rest of this walkthrough: u = v + w.
VectorXf v, w;
VectorXf u = v + w;
VectorXf u = v + w;
// Result type of the `+`: a CwiseBinaryOp parameterised on the float sum
// functor and on the two operand types (notation, not a C++ declaration).
MatrixBase::operator+(const MatrixBase&) ->
CwiseBinaryOp<
internal::scalar_sum_op<float>,
VectorXf, VectorXf
>;
// Why not
// MatrixBase::operator+(const CwiseBinaryOp<...> &);
// CwiseBinaryOp is a type
// Don't want to redefine all operations
// for every possible combination.
// CRTP: CwiseBinaryOp passes itself to MatrixBase as the derived type.
// The base is public so that a CwiseBinaryOp can bind to a
// `const MatrixBase<OtherDerived>&` parameter. Body elided.
class CwiseBinaryOp : public MatrixBase<CwiseBinaryOp> { /* ... */ };
VectorXf u = CwiseBinaryOp<...>(v, w);
// Assignment from an expression seen through its MatrixBase<OtherDerived>
// base: forwards the concrete expression (other.derived()) to the base
// class operator=. The template header introducing OtherDerived is not shown.
Matrix& operator=(const MatrixBase<OtherDerived>& other) {
return Base::operator=(other.derived());
}
// Declaration of the MatrixBase assignment operator (template header not shown).
Derived&
MatrixBase::operator=(const MatrixBase<OtherDerived>& other);
// With
// (template-parameter bindings for this example; notation, not C++ statements)
Derived = VectorXf;
OtherDerived = CwiseBinaryOp<...>;
// The statement being traced, with the sum already replaced by its node.
VectorXf u = CwiseBinaryOp<...>(v, w);
// ASSIGNMENT STRATEGY SELECTION
// Hands the concrete destination (derived()) and concrete source
// (other.derived()) to internal::assign_selector<Derived, OtherDerived>::run
// and returns its result. Only two template arguments are given here, so
// any further selector parameters take their defaults.
Derived&
MatrixBase::operator=(const MatrixBase<OtherDerived>& other) {
return internal::assign_selector<Derived,OtherDerived>::
run(derived(), other.derived());
}
// Primary template of the selector. EvalBeforeAssigning defaults to whether
// EvalBeforeAssigningBit is set in the source expression's Flags; the
// default for NeedToTranspose is elided in this excerpt.
template<typename Derived, typename OtherDerived,
bool EvalBeforeAssigning = int(OtherDerived::Flags) &
EvalBeforeAssigningBit,
bool NeedToTranspose = ...>
struct internal::assign_selector;
// ASSIGNMENT STRATEGY SELECTION
// Specialisation for EvalBeforeAssigning == false and
// NeedToTranspose == false: the source is passed straight to the
// destination's lazyAssign, whose result is returned.
template<typename Derived, typename OtherDerived>
struct internal::assign_selector<Derived,OtherDerived,false,false>
{
    static Derived& run(Derived& dst, const OtherDerived& other) {
        return dst.lazyAssign(other.derived());
    }
};
// Runs internal::assign_impl<Derived, OtherDerived>::run on the concrete
// destination and the concrete source, then returns the destination.
// The template parameter list is elided (`template<...>`).
template<...>
Derived& MatrixBase<Derived>::lazyAssign(
const MatrixBase<OtherDerived>& other
) {
internal::assign_impl<Derived, OtherDerived>
::run(derived(),other.derived());
return derived();
}
// ASSIGNMENT EXECUTION
// assign_impl variant tagged LinearVectorization / NoUnrolling (template
// header not shown). run() copies src into dst over three index ranges:
//   [0, alignedStart)          one coefficient at a time
//   [alignedStart, alignedEnd) one packet (packetSize coefficients) at a time
//   [alignedEnd, size)         one coefficient at a time
// alignedStart, alignedEnd, packetSize and size are presumably set up in the
// elided [...] part.
struct internal::assign_impl<
Derived1, Derived2, LinearVectorization, NoUnrolling>
{
static void run(Derived1 &dst, const Derived2 &src) {
[...]
// Unaligned copy: leading coefficients before the aligned range
for(int index = 0; index < alignedStart; index++)
dst.copyCoeff(index, src);
// Vectorized copy: index advances by a whole packet per iteration
for(int index = alignedStart; index < alignedEnd; index += packetSize) {
dst.template copyPacket<
Derived2, Aligned, internal::assign_traits<Derived1,Derived2
>::SrcAlignment>(index, src);
}
// Unaligned copy: trailing coefficients after the aligned range
for(int index = alignedEnd; index < size; index++)
dst.copyCoeff(index, src);
}
};
// VECTORIZED ASSIGNMENT EXECUTION
// Vector copy loop from prev. slide
// index advances by packetSize, so each copyPacket call covers one packet.
for(int index = alignedStart; index < alignedEnd; index += packetSize) {
dst.template copyPacket<
Derived2, Aligned, internal::assign_traits<Derived1,Derived2
>::SrcAlignment>(index, src);
}
// Copies one packet at `index`: the packet is read from the source
// expression with packet<LoadMode>(index) and written into this object with
// writePacket<StoreMode>. StoreMode and LoadMode are presumably template
// parameters of copyPacket; their declaration is not shown here.
inline void MatrixBase<Derived>::copyPacket(
int index, const MatrixBase<OtherDerived>& other
) {
// STORING DATA
derived().template writePacket<StoreMode>(
index,
// LOADING DATA
other.derived().template packet<LoadMode>(index)
);
}
// STORING
// Writes packet x to the address m_storage.data() + index through
// internal::pstoret, forwarding StoreMode as the mode argument.
template<int StoreMode>
inline void writePacket(int index, const PacketScalar& x)
{
internal::pstoret<Scalar, PacketScalar, StoreMode>
(m_storage.data() + index, x);
}
// Store dispatch on the mode template argument: pstore when it equals
// Aligned, pstoreu otherwise.
// NOTE(review): the mode parameter is named LoadMode although this is the
// store path (writePacket passes its StoreMode in this position).
template<typename Scalar, typename Packet, int LoadMode>
inline void internal::pstoret(Scalar* to, const Packet& from)
{
if(LoadMode == Aligned)
internal::pstore(to, from);
else
internal::pstoreu(to, from);
}
// Specialisation for float data held in an __m128: the store is a single
// _mm_store_ps call. pstoret reaches pstore only on its Aligned branch.
template<> inline void internal::pstore(
float* to, const __m128& from
) {
_mm_store_ps(to, from);
}
// LOADING and SUMMING
// packet(index) asks each operand (m_lhs, m_rhs) for its packet at `index`
// with the same LoadMode, and combines the two with the functor's packetOp.
class CwiseBinaryOp {
template<int LoadMode>
inline PacketScalar packet(int index) const {
// SUM
return m_functor.packetOp(
// LHS VECTOR LOAD
m_lhs.template packet<LoadMode>(index),
// RHS VECTOR LOAD
m_rhs.template packet<LoadMode>(index)
);
}
};
// LOADS are similar to STORES
// LOADING and SUMMING
// Same member with the operand loads elided: the value returned is
// whatever m_functor.packetOp produces.
class CwiseBinaryOp {
template<int LoadMode>
inline PacketScalar packet(int index) const {
// SUM
return m_functor.packetOp(...);
}
};
// Sum functor. packetOp takes two packets of the same type and returns
// internal::padd(a, b).
template<typename Scalar> struct internal::scalar_sum_op {
template<typename PacketScalar>
inline const PacketScalar packetOp(
const PacketScalar& a, const PacketScalar& b
) const {
return internal::padd(a,b);
}
};
// Specialisation for __m128 packets: the addition is one _mm_add_ps call.
template<> inline __m128 internal::padd (
const __m128& a, const __m128& b
) { return _mm_add_ps(a,b); }
// The statement traced above and the three loops it was reduced to.
VectorXf u = v + w;
// Has now become
// Unaligned copy: leading coefficients, one at a time
for(int index = 0; index < alignedStart; index++)
dst.copyCoeff(index, src);
// Vectorized copy: one packet per iteration
for(int index = alignedStart; index < alignedEnd; index += packetSize) {
// _mm_store_ps, _mm_add_ps, _mm_load_ps
}
// Unaligned copy: trailing coefficients, one at a time
for(int index = alignedEnd; index < size; index++)
dst.copyCoeff(index, src);