Implement simd_builder for x86

ASMJIT-based tool for building vectorized loops (such as the ones in BufferUtils.cpp)
This commit is contained in:
Nekotekina 2022-08-24 19:36:37 +03:00 committed by Ivan
parent 698c3415ea
commit e28707055b
5 changed files with 740 additions and 387 deletions

View file

@ -7,6 +7,7 @@
#include "mutex.h"
#include "util/vm.hpp"
#include "util/asm.hpp"
#include "util/v128.hpp"
#include <charconv>
#include <zlib.h>
@ -351,6 +352,424 @@ asmjit::inline_runtime::~inline_runtime()
utils::memory_protect(m_data, m_size, utils::protection::rx);
}
#if defined(ARCH_X64)
// Construct a SIMD code builder over the given code holder.
// Selects the widest vector register file supported by the host CPU
// (ZMM/YMM/XMM) via _init(true).
asmjit::simd_builder::simd_builder(CodeHolder* ch) noexcept
: native_asm(ch)
{
_init(true);
}
// Configure the working register set: publish operand aliases v0..v5 for the
// widest vector registers available (or force XMM when full == false for the
// tail loop), record the vector byte width in vsize, and advertise AVX-512
// masking capability through vmask.
void asmjit::simd_builder::_init(bool full)
{
    if (full && utils::has_avx512_icl())
    {
        // AVX-512 (Icelake feature tier): 512-bit ZMM registers
        vsize = 64;
        v0 = x86::zmm0;
        v1 = x86::zmm1;
        v2 = x86::zmm2;
        v3 = x86::zmm3;
        v4 = x86::zmm4;
        v5 = x86::zmm5;
    }
    else if (full && utils::has_avx2())
    {
        // AVX2: 256-bit YMM registers
        vsize = 32;
        v0 = x86::ymm0;
        v1 = x86::ymm1;
        v2 = x86::ymm2;
        v3 = x86::ymm3;
        v4 = x86::ymm4;
        v5 = x86::ymm5;
    }
    else
    {
        // Baseline: 128-bit XMM registers
        vsize = 16;
        v0 = x86::xmm0;
        v1 = x86::xmm1;
        v2 = x86::xmm2;
        v3 = x86::xmm3;
        v4 = x86::xmm4;
        v5 = x86::xmm5;
    }

    // Non-zero vmask means AVX-512 masking is usable; values 1..7 select an
    // active k-register for the masked tail iteration (see build_loop).
    vmask = (full && utils::has_avx512()) ? -1 : 0;
}
// Emit a function epilogue: on AVX-capable hosts that used wide (YMM/ZMM)
// registers, clear the upper lanes first to avoid AVX->SSE transition
// penalties, then return.
void asmjit::simd_builder::vec_cleanup_ret()
{
    const bool wide_regs_used = vsize > 16 && utils::has_avx();

    if (wide_regs_used)
        this->vzeroupper();

    this->ret();
}
void asmjit::simd_builder::vec_set_all_zeros(const Operand& v)
{
x86::Xmm reg(v.id());
if (utils::has_avx())
this->vpxor(reg, reg, reg);
else
this->xorps(reg, reg);
}
// Fill the vector register v entirely with 1-bits, without loading a constant.
// Operand equality compares full signature + id, so "zr == v" holds only when
// v itself was passed as a Zmm (and likewise for Ymm).
void asmjit::simd_builder::vec_set_all_ones(const Operand& v)
{
x86::Xmm reg(v.id());
// ZMM: pcmpeqd has no 512-bit form; vpternlogd imm 0xff produces all ones
if (x86::Zmm zr(v.id()); zr == v)
this->vpternlogd(zr, zr, zr, 0xff);
else if (x86::Ymm yr(v.id()); yr == v)
this->vpcmpeqd(yr, yr, yr);
else if (utils::has_avx())
this->vpcmpeqd(reg, reg, reg);
else
this->pcmpeqd(reg, reg);
}
// Materialize the 128-bit constant val into v, broadcasting it across the
// full register width for YMM/ZMM destinations. All-zero and all-one values
// use cheaper register-only idioms instead of a memory load.
void asmjit::simd_builder::vec_set_const(const Operand& v, const v128& val)
{
if (!val._u)
return vec_set_all_zeros(v);
if (!~val._u)
return vec_set_all_ones(v);
// Addresses below 2 GiB fit an absolute 32-bit displacement encoding
if (uptr(&val) < 0x8000'0000)
{
// Assume the constant comes from a code or data segment (unsafe)
if (x86::Zmm zr(v.id()); zr == v)
this->vbroadcasti32x4(zr, x86::oword_ptr(uptr(&val)));
else if (x86::Ymm yr(v.id()); yr == v)
this->vbroadcasti128(yr, x86::oword_ptr(uptr(&val)));
else if (utils::has_avx())
this->vmovaps(x86::Xmm(v.id()), x86::oword_ptr(uptr(&val)));
else
this->movaps(x86::Xmm(v.id()), x86::oword_ptr(uptr(&val)));
}
else
{
// TODO
fmt::throw_exception("Unexpected constant location");
}
}
// Emit a test setting ZF when (v & rhs) == 0. Here esize is the tested
// operand width in bytes (64 = ZMM, 32 = YMM/XMM via [v]ptest, 16 = XMM,
// smaller values test only the low bytes).
// NOTE(review): the fallback path clobbers v (pand/packsswb) and RAX, and the
// ZMM path clobbers k0 - callers must treat these as scratch.
void asmjit::simd_builder::vec_clobbering_test(u32 esize, const Operand& v, const Operand& rhs)
{
if (esize == 64)
{
// AVX-512: compare per-dword into mask k0, then test the mask itself
this->emit(x86::Inst::kIdVptestmd, x86::k0, v, rhs);
this->ktestw(x86::k0, x86::k0);
}
else if (esize == 32)
{
this->emit(x86::Inst::kIdVptest, v, rhs);
}
else if (esize == 16 && utils::has_sse41())
{
this->emit(x86::Inst::kIdPtest, v, rhs);
}
else
{
// No ptest available (or partial width): AND, move low bits to a GPR, test
if (v != rhs)
this->emit(x86::Inst::kIdPand, v, rhs);
if (esize == 16)
// Narrow 128 bits of words to 64 bits so a single movq covers them
this->emit(x86::Inst::kIdPacksswb, v, v);
this->emit(x86::Inst::kIdMovq, x86::rax, v);
if (esize == 16 || esize == 8)
this->test(x86::rax, x86::rax);
else if (esize == 4)
this->test(x86::eax, x86::eax);
else if (esize == 2)
this->test(x86::ax, x86::ax);
else if (esize == 1)
this->test(x86::al, x86::al);
else
fmt::throw_exception("Unimplemented");
}
}
// Build x86::ptr(base, index, shift, 0) where the SIB scale exponent matches
// the element size (1 << shift == esize). Throws on unsupported sizes.
asmjit::x86::Mem asmjit::simd_builder::ptr_scale_for_vec(u32 esize, const x86::Gp& base, const x86::Gp& index)
{
    u32 shift = 0;

    switch (ensure(esize))
    {
    case 1: shift = 0; break;
    case 2: shift = 1; break;
    case 4: shift = 2; break;
    case 8: shift = 3; break;
    default: fmt::throw_exception("Bad esize");
    }

    return x86::ptr(base, index, shift, 0);
}
// Load esize bytes from src into v (esize >= 16 loads the full vector width).
// Instruction selection mirrors vec_store_unaligned: the EVEX word/dword/qword
// moves (vmovdqu16/32/64) are chosen when a dynamic tail mask is active
// (vmask in 1..7) OR when the register is 512-bit, because the plain
// vmovdqu/movups forms have no ZMM encoding. Previously the "|| vsize >= 64"
// term was missing here (unlike the store path), so an unmasked ZMM partial
// load recursed into the esize >= 16 branch and emitted non-EVEX kIdVmovdqu.
void asmjit::simd_builder::vec_load_unaligned(u32 esize, const Operand& v, const x86::Mem& src)
{
    ensure(std::has_single_bit(esize));
    ensure(std::has_single_bit(vsize));
    if (esize == 2)
    {
        ensure(vsize >= 2);
        if (vsize == 2)
            vec_set_all_zeros(v); // clear the register before the single-word insert
        if (vsize == 2 && utils::has_avx())
            this->emit(x86::Inst::kIdVpinsrw, x86::Xmm(v.id()), x86::Xmm(v.id()), src, Imm(0));
        else if (vsize == 2)
            this->emit(x86::Inst::kIdPinsrw, v, src, Imm(0));
        else if ((vmask && vmask < 8) || vsize >= 64)
            this->emit(x86::Inst::kIdVmovdqu16, v, src);
        else
            return vec_load_unaligned(vsize, v, src); // full-width load instead
    }
    else if (esize == 4)
    {
        ensure(vsize >= 4);
        if (vsize == 4 && utils::has_avx())
            this->emit(x86::Inst::kIdVmovd, x86::Xmm(v.id()), src);
        else if (vsize == 4)
            this->emit(x86::Inst::kIdMovd, v, src);
        else if ((vmask && vmask < 8) || vsize >= 64)
            this->emit(x86::Inst::kIdVmovdqu32, v, src);
        else
            return vec_load_unaligned(vsize, v, src);
    }
    else if (esize == 8)
    {
        ensure(vsize >= 8);
        if (vsize == 8 && utils::has_avx())
            this->emit(x86::Inst::kIdVmovq, x86::Xmm(v.id()), src);
        else if (vsize == 8)
            this->emit(x86::Inst::kIdMovq, v, src);
        else if ((vmask && vmask < 8) || vsize >= 64)
            this->emit(x86::Inst::kIdVmovdqu64, v, src);
        else
            return vec_load_unaligned(vsize, v, src);
    }
    else if (esize >= 16)
    {
        ensure(vsize >= 16);
        if ((vmask && vmask < 8) || vsize >= 64)
            this->emit(x86::Inst::kIdVmovdqu64, v, src); // EVEX form, matches the store path
        else if (utils::has_avx())
            this->emit(x86::Inst::kIdVmovdqu, v, src);
        else
            this->emit(x86::Inst::kIdMovups, v, src);
    }
    else
    {
        fmt::throw_exception("Unimplemented");
    }
}
// Store esize bytes from v to dst (esize >= 16 stores the full vector width).
// The EVEX word/dword/qword moves (vmovdqu16/32/64) are required when a
// dynamic tail mask is active (vmask in 1..7) or when the register is 512-bit,
// since plain vmovdqu/movups have no ZMM encoding.
void asmjit::simd_builder::vec_store_unaligned(u32 esize, const Operand& v, const x86::Mem& dst)
{
ensure(std::has_single_bit(esize));
ensure(std::has_single_bit(vsize));
if (esize == 2)
{
ensure(vsize >= 2);
if (vsize == 2 && utils::has_avx())
this->emit(x86::Inst::kIdVpextrw, dst, x86::Xmm(v.id()), Imm(0));
else if (vsize == 2 && utils::has_sse41())
this->emit(x86::Inst::kIdPextrw, dst, v, Imm(0));
else if (vsize == 2)
// Pre-SSE4.1 pextrw can only target a GPR: bounce through RAX (preserved)
this->push(x86::rax), this->pextrw(x86::eax, x86::Xmm(v.id()), 0), this->mov(dst, x86::ax), this->pop(x86::rax);
else if ((vmask && vmask < 8) || vsize >= 64)
this->emit(x86::Inst::kIdVmovdqu16, dst, v);
else
return vec_store_unaligned(vsize, v, dst); // full-width store instead
}
else if (esize == 4)
{
ensure(vsize >= 4);
if (vsize == 4 && utils::has_avx())
this->emit(x86::Inst::kIdVmovd, dst, x86::Xmm(v.id()));
else if (vsize == 4)
this->emit(x86::Inst::kIdMovd, dst, v);
else if ((vmask && vmask < 8) || vsize >= 64)
this->emit(x86::Inst::kIdVmovdqu32, dst, v);
else
return vec_store_unaligned(vsize, v, dst);
}
else if (esize == 8)
{
ensure(vsize >= 8);
if (vsize == 8 && utils::has_avx())
this->emit(x86::Inst::kIdVmovq, dst, x86::Xmm(v.id()));
else if (vsize == 8)
this->emit(x86::Inst::kIdMovq, dst, v);
else if ((vmask && vmask < 8) || vsize >= 64)
this->emit(x86::Inst::kIdVmovdqu64, dst, v);
else
return vec_store_unaligned(vsize, v, dst);
}
else if (esize >= 16)
{
ensure(vsize >= 16);
if ((vmask && vmask < 8) || vsize >= 64)
this->emit(x86::Inst::kIdVmovdqu64, dst, v); // Not really needed
else if (utils::has_avx())
this->emit(x86::Inst::kIdVmovdqu, dst, v);
else
this->emit(x86::Inst::kIdMovups, dst, v);
}
else
{
fmt::throw_exception("Unimplemented");
}
}
// Emit dst = lhs OP rhs, picking between the 3-operand EVEX/VEX encodings and
// the destructive 2-operand SSE encoding depending on host features.
void asmjit::simd_builder::_vec_binary_op(x86::Inst::Id sse_op, x86::Inst::Id vex_op, x86::Inst::Id evex_op, const Operand& dst, const Operand& lhs, const Operand& rhs)
{
if (utils::has_avx())
{
// Force EVEX when no VEX form exists or when a {k} mask register is
// currently attached (masking is EVEX-only)
if (vex_op == x86::Inst::kIdNone || this->_extraReg.isReg())
{
this->evex().emit(evex_op, dst, lhs, rhs);
}
else
{
this->emit(vex_op, dst, lhs, rhs);
}
}
else if (dst == lhs)
{
// SSE: destination doubles as the first source
this->emit(sse_op, dst, rhs);
}
else if (dst == rhs)
{
// Would require a temporary (or commutativity knowledge); not needed so far
fmt::throw_exception("Unimplemented");
}
else
{
// Emulate the non-destructive form with an extra register copy
this->emit(x86::Inst::kIdMovaps, dst, lhs);
this->emit(sse_op, dst, rhs);
}
}
// Emit dst = unsigned-min(lhs, rhs) per element. Supported element sizes:
// 2 (pminuw) and 4 (pminud); both require SSE4.1 as the baseline.
void asmjit::simd_builder::vec_umin(u32 esize, const Operand& dst, const Operand& lhs, const Operand& rhs)
{
    using enum x86::Inst::Id;

    if (utils::has_sse41())
    {
        switch (esize)
        {
        case 2: return _vec_binary_op(kIdPminuw, kIdVpminuw, kIdVpminuw, dst, lhs, rhs);
        case 4: return _vec_binary_op(kIdPminud, kIdVpminud, kIdVpminud, dst, lhs, rhs);
        default: break;
        }
    }

    fmt::throw_exception("Unimplemented");
}
// Emit dst = unsigned-max(lhs, rhs) per element. Supported element sizes:
// 2 (pmaxuw) and 4 (pmaxud); both require SSE4.1 as the baseline.
void asmjit::simd_builder::vec_umax(u32 esize, const Operand& dst, const Operand& lhs, const Operand& rhs)
{
    using enum x86::Inst::Id;

    if (utils::has_sse41())
    {
        switch (esize)
        {
        case 2: return _vec_binary_op(kIdPmaxuw, kIdVpmaxuw, kIdVpmaxuw, dst, lhs, rhs);
        case 4: return _vec_binary_op(kIdPmaxud, kIdVpmaxud, kIdVpmaxud, dst, lhs, rhs);
        default: break;
        }
    }

    fmt::throw_exception("Unimplemented");
}
// Reduce the low 128 bits of src to the minimum unsigned element, leaving the
// result in GPR dst. Requires SSE4.1.
// NOTE(review): clobbers tmp, and src as well in the esize == 4 path.
void asmjit::simd_builder::vec_umin_horizontal_i128(u32 esize, const x86::Gp& dst, const Operand& src, const Operand& tmp)
{
using enum x86::Inst::Id;
if (!utils::has_sse41())
{
fmt::throw_exception("Unimplemented");
}
ensure(src != tmp);
if (esize == 2)
{
// phminposuw computes the horizontal u16 minimum directly (value in element 0)
this->emit(utils::has_avx() ? kIdVphminposuw : kIdPhminposuw, x86::Xmm(tmp.id()), x86::Xmm(src.id()));
this->emit(utils::has_avx() ? kIdVpextrw : kIdPextrw, dst, x86::Xmm(tmp.id()), Imm(0));
}
else if (esize == 4)
{
// log2 reduction: fold the upper 64 bits onto the lower, then upper 32
if (utils::has_avx())
{
this->vpsrldq(x86::Xmm(tmp.id()), x86::Xmm(src.id()), 8);
this->vpminud(x86::Xmm(tmp.id()), x86::Xmm(tmp.id()), x86::Xmm(src.id()));
this->vpsrldq(x86::Xmm(src.id()), x86::Xmm(tmp.id()), 4);
this->vpminud(x86::Xmm(src.id()), x86::Xmm(src.id()), x86::Xmm(tmp.id()));
this->vmovd(dst.r32(), x86::Xmm(src.id()));
}
else
{
// SSE variants are destructive, hence the extra movdqa copies
this->movdqa(x86::Xmm(tmp.id()), x86::Xmm(src.id()));
this->psrldq(x86::Xmm(tmp.id()), 8);
this->pminud(x86::Xmm(tmp.id()), x86::Xmm(src.id()));
this->movdqa(x86::Xmm(src.id()), x86::Xmm(tmp.id()));
this->psrldq(x86::Xmm(src.id()), 4);
this->pminud(x86::Xmm(src.id()), x86::Xmm(tmp.id()));
this->movd(dst.r32(), x86::Xmm(src.id()));
}
}
else
{
fmt::throw_exception("Unimplemented");
}
}
// Reduce the low 128 bits of src to the maximum unsigned element, leaving the
// result in GPR dst. Requires SSE4.1.
// NOTE(review): clobbers tmp, and src as well in the esize == 4 path.
void asmjit::simd_builder::vec_umax_horizontal_i128(u32 esize, const x86::Gp& dst, const Operand& src, const Operand& tmp)
{
using enum x86::Inst::Id;
if (!utils::has_sse41())
{
fmt::throw_exception("Unimplemented");
}
ensure(src != tmp);
if (esize == 2)
{
// u16 max via complement: max(x) == ~min(~x), reusing phminposuw
vec_set_all_ones(x86::Xmm(tmp.id()));
vec_xor(esize, x86::Xmm(tmp.id()), x86::Xmm(tmp.id()), x86::Xmm(src.id()));
this->emit(utils::has_avx() ? kIdVphminposuw : kIdPhminposuw, x86::Xmm(tmp.id()), x86::Xmm(tmp.id()));
this->emit(utils::has_avx() ? kIdVpextrw : kIdPextrw, dst, x86::Xmm(tmp.id()), Imm(0));
// Undo the complement on the scalar result
this->not_(dst.r16());
}
else if (esize == 4)
{
// log2 reduction: fold the upper 64 bits onto the lower, then upper 32
if (utils::has_avx())
{
this->vpsrldq(x86::Xmm(tmp.id()), x86::Xmm(src.id()), 8);
this->vpmaxud(x86::Xmm(tmp.id()), x86::Xmm(tmp.id()), x86::Xmm(src.id()));
this->vpsrldq(x86::Xmm(src.id()), x86::Xmm(tmp.id()), 4);
this->vpmaxud(x86::Xmm(src.id()), x86::Xmm(src.id()), x86::Xmm(tmp.id()));
this->vmovd(dst.r32(), x86::Xmm(src.id()));
}
else
{
// SSE variants are destructive, hence the extra movdqa copies
this->movdqa(x86::Xmm(tmp.id()), x86::Xmm(src.id()));
this->psrldq(x86::Xmm(tmp.id()), 8);
this->pmaxud(x86::Xmm(tmp.id()), x86::Xmm(src.id()));
this->movdqa(x86::Xmm(src.id()), x86::Xmm(tmp.id()));
this->psrldq(x86::Xmm(src.id()), 4);
this->pmaxud(x86::Xmm(src.id()), x86::Xmm(tmp.id()));
this->movd(dst.r32(), x86::Xmm(src.id()));
}
}
else
{
fmt::throw_exception("Unimplemented");
}
}
#endif /* X86 */
#ifdef LLVM_AVAILABLE
#include <unordered_map>

View file

@ -51,6 +51,8 @@ using native_asm = asmjit::a64::Assembler;
using native_args = std::array<asmjit::a64::Gp, 4>;
#endif
union v128;
void jit_announce(uptr func, usz size, std::string_view name);
void jit_announce(auto* func, usz size, std::string_view name)
@ -211,40 +213,132 @@ namespace asmjit
}
#if defined(ARCH_X64)
template <uint Size>
struct native_vec;
template <>
struct native_vec<16> { using type = x86::Xmm; };
template <>
struct native_vec<32> { using type = x86::Ymm; };
template <>
struct native_vec<64> { using type = x86::Zmm; };
template <uint Size>
using native_vec_t = typename native_vec<Size>::type;
// if (count > step) { for (; ctr < (count - step); ctr += step) {...} count -= ctr; }
inline void build_incomplete_loop(native_asm& c, auto ctr, auto count, u32 step, auto&& build)
struct simd_builder : native_asm
{
asmjit::Label body = c.newLabel();
asmjit::Label exit = c.newLabel();
Operand v0, v1, v2, v3, v4, v5;
ensure((step & (step - 1)) == 0);
c.cmp(count, step);
c.jbe(exit);
c.sub(count, step);
c.align(asmjit::AlignMode::kCode, 16);
c.bind(body);
build();
c.add(ctr, step);
c.sub(count, step);
c.ja(body);
c.add(count, step);
c.bind(exit);
}
uint vsize = 16;
uint vmask = 0;
simd_builder(CodeHolder* ch) noexcept;
void _init(bool full);
void vec_cleanup_ret();
void vec_set_all_zeros(const Operand& v);
void vec_set_all_ones(const Operand& v);
void vec_set_const(const Operand& v, const v128& value);
void vec_clobbering_test(u32 esize, const Operand& v, const Operand& rhs);
// return x86::ptr(base, ctr, X, 0) where X is set for esize accordingly
x86::Mem ptr_scale_for_vec(u32 esize, const x86::Gp& base, const x86::Gp& index);
void vec_load_unaligned(u32 esize, const Operand& v, const x86::Mem& src);
void vec_store_unaligned(u32 esize, const Operand& v, const x86::Mem& dst);
void vec_partial_move(u32 esize, const Operand& dst, const Operand& src);
void _vec_binary_op(x86::Inst::Id sse_op, x86::Inst::Id vex_op, x86::Inst::Id evex_op, const Operand& dst, const Operand& lhs, const Operand& rhs);
// Per-byte shuffle (pshufb family): dst = shuffle(lhs, control = rhs).
void vec_shuffle_xi8(const Operand& dst, const Operand& lhs, const Operand& rhs)
{
    _vec_binary_op(x86::Inst::kIdPshufb, x86::Inst::kIdVpshufb, x86::Inst::kIdVpshufb, dst, lhs, rhs);
}
// Bitwise XOR (element size is irrelevant, hence the unnamed u32 parameter).
void vec_xor(u32, const Operand& dst, const Operand& lhs, const Operand& rhs)
{
    _vec_binary_op(x86::Inst::kIdPxor, x86::Inst::kIdVpxor, x86::Inst::kIdVpxord, dst, lhs, rhs);
}
// Bitwise OR (element size is irrelevant, hence the unnamed u32 parameter).
void vec_or(u32, const Operand& dst, const Operand& lhs, const Operand& rhs)
{
    _vec_binary_op(x86::Inst::kIdPor, x86::Inst::kIdVpor, x86::Inst::kIdVpord, dst, lhs, rhs);
}
void vec_umin(u32 esize, const Operand& dst, const Operand& lhs, const Operand& rhs);
void vec_umax(u32 esize, const Operand& dst, const Operand& lhs, const Operand& rhs);
void vec_umin_horizontal_i128(u32 esize, const x86::Gp& dst, const Operand& src, const Operand& tmp);
void vec_umax_horizontal_i128(u32 esize, const x86::Gp& dst, const Operand& src, const Operand& tmp);
// Attach merge-masking with k{vmask} to the next emitted instruction when a
// dynamic tail mask (vmask in 1..7) is active; otherwise a no-op.
simd_builder& keep_if_not_masked()
{
    if (vmask && vmask < 8)
        this->k(x86::KReg(vmask));

    return *this;
}
// Attach zero-masking {k}{z} with k{vmask} to the next emitted instruction
// when a dynamic tail mask (vmask in 1..7) is active; otherwise a no-op.
simd_builder& zero_if_not_masked()
{
    if (vmask && vmask < 8)
    {
        this->k(x86::KReg(vmask));
        this->z();
    }

    return *this;
}
// Emit a vectorized loop over reg_cnt elements of size esize: the main body
// processes vsize/esize elements per iteration via build(); the remaining
// elements are handled either by a single masked iteration (AVX-512 path,
// vmask != 0) or by a one-element-at-a-time tail loop. reduce() is emitted
// once, after all element processing.
// NOTE(review): clobbers reg_ctr and reg_cnt; the masked path also clobbers
// k7. reduce() may be emitted before the main loop ever ran - callers must
// pre-initialize any accumulator registers it reads.
void build_loop(u32 esize, auto reg_ctr, auto reg_cnt, auto&& build, auto&& reduce)
{
ensure((esize & (esize - 1)) == 0);
ensure(esize <= vsize);
Label body = this->newLabel();
Label next = this->newLabel();
Label exit = this->newLabel();
// Elements consumed per full-width iteration
const u32 step = vsize / esize;
this->xor_(reg_ctr.r32(), reg_ctr.r32()); // Reset counter reg
this->sub(reg_cnt, step);
this->jb(next); // If count < step, skip main loop body
this->align(AlignMode::kCode, 16);
this->bind(body);
build();
this->add(reg_ctr, step);
this->sub(reg_cnt, step);
this->ja(body);
this->bind(next);
if (!vmask)
reduce();
// Restore the remaining element count (0..step); zero means we are done
this->add(reg_cnt, step);
this->jz(exit);
if (vmask)
{
// Build single last iteration (masked)
static constexpr u64 all_ones = -1;
// reg_cnt = mask with the low reg_cnt bits set (one bit per element)
this->bzhi(reg_cnt, x86::Mem(uptr(&all_ones)), reg_cnt);
this->kmovq(x86::k7, reg_cnt);
// Temporarily route masked emission through k7 for this build() only
vmask = 7;
build();
vmask = -1;
reduce();
}
else
{
// Build tail loop (reduced vector width)
Label body = this->newLabel();
this->align(AlignMode::kCode, 16);
this->bind(body);
// vsize/step == esize: shrink to one element per iteration for build()
const uint vsz = vsize / step;
this->_init(false);
vsize = vsz;
build();
// Restore the full-width register configuration for subsequent code
this->_init(true);
this->inc(reg_ctr);
this->sub(reg_cnt, 1);
this->ja(body);
}
this->bind(exit);
}
};
// for (; count > 0; ctr++, count--)
inline void build_loop(native_asm& c, auto ctr, auto count, auto&& build)
@ -262,6 +356,27 @@ namespace asmjit
c.ja(body);
c.bind(exit);
}
// Pushes `count` synthetic return addresses (spaced 16 bytes apart, walking
// backwards from `next`) and then executes `count` RET instructions, so the
// net stack effect is zero and execution resumes at `next`.
// NOTE(review): clobbers RCX.
inline void maybe_flush_lbr(native_asm& c, uint count = 2)
{
// Workaround for bad LBR callstacks which happen in some situations (mainly TSX) - execute additional RETs
Label next = c.newLabel();
c.lea(x86::rcx, x86::qword_ptr(next));
for (u32 i = 0; i < count; i++)
{
c.push(x86::rcx);
c.sub(x86::rcx, 16);
}
for (u32 i = 0; i < count; i++)
{
// NOTE(review): assumes each popped address lands on one of these 16-byte
// aligned RET slots preceding `next` - confirm the alignment assumption
c.ret();
c.align(asmjit::AlignMode::kCode, 16);
}
c.bind(next);
}
#endif
}