// Copyright 2023 Matthew Kolbe
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#if defined(HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_) == \
    defined(HWY_TARGET_TOGGLE)
#ifdef HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_
#undef HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_
#else
#define HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_
#endif

#include <cstdlib>  // std::abs

#include "hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {

namespace hn = hwy::HWY_NAMESPACE;

// CRTP base class describing one unit of work for the single-input
// `Unroller`. DERIVED must define `Func(idx, x, y)` (the base `Func` only
// forwards to `me()->Func`). Every other hook forwards to a `...Impl` method
// whose default is defined here; DERIVED may shadow any of those.
//
// `d_in`/`d_out` are rebound from a descriptor of the wider of IN_T/OUT_T, so
// input and output vectors have the same number of lanes.
template <class DERIVED, typename IN_T, typename OUT_T>
struct UnrollerUnit {
  static constexpr size_t kMaxTSize = HWY_MAX(sizeof(IN_T), sizeof(OUT_T));
  using LargerT = SignedFromSize<kMaxTSize>;  // only the size matters.

  DERIVED* me() { return static_cast<DERIVED*>(this); }

  // Compile-time upper bound on the number of lanes per vector.
  static constexpr size_t MaxUnitLanes() {
    return HWY_MAX_LANES_D(hn::ScalableTag<LargerT>);
  }
  // Number of lanes per vector on the current target; `Unroller` advances by
  // this many elements per vector.
  static size_t ActualLanes() { return Lanes(hn::ScalableTag<LargerT>()); }

  using LargerD = hn::CappedTag<LargerT, MaxUnitLanes()>;
  using IT = hn::Rebind<IN_T, LargerD>;
  using OT = hn::Rebind<OUT_T, LargerD>;
  IT d_in;
  OT d_out;
  using Y_VEC = hn::Vec<OT>;
  using X_VEC = hn::Vec<IT>;

  // Computes the output vector for the elements starting at `idx` from the
  // loaded input `x` and `y`, the value returned by the previous `Func` call
  // on the same accumulator (initially `YInit()`).
  Y_VEC Func(const ptrdiff_t idx, const X_VEC x, const Y_VEC y) {
    return me()->Func(idx, x, y);
  }

  // Initial value of input vectors (default: zero).
  X_VEC X0Init() { return me()->X0InitImpl(); }

  X_VEC X0InitImpl() { return hn::Zero(d_in); }

  // Initial value of output/accumulator vectors (default: zero).
  Y_VEC YInit() { return me()->YInitImpl(); }

  Y_VEC YInitImpl() { return hn::Zero(d_out); }

  // Loads a full vector starting at `from + idx` (default: unaligned load).
  X_VEC Load(const ptrdiff_t idx, const IN_T* from) {
    return me()->LoadImpl(idx, from);
  }

  X_VEC LoadImpl(const ptrdiff_t idx, const IN_T* from) {
    return hn::LoadU(d_in, from + idx);
  }

  // Returns a mask over `d` selecting the lanes described by `places`: the
  // first `places` lanes if `places >= 0`, otherwise the last |places| lanes
  // of an `ActualLanes()`-wide vector. Only the mask that is needed is built.
  template <class D>
  static hn::Mask<D> PlacesMask(D d, const ptrdiff_t places) {
    if (places < 0) {
      const ptrdiff_t num_skipped =
          places + static_cast<ptrdiff_t>(ActualLanes());
      return hn::Not(hn::FirstN(d, static_cast<size_t>(num_skipped)));
    }
    return hn::FirstN(d, static_cast<size_t>(places));
  }

  // MaskLoad takes either a positive or negative `places`. A positive value
  // loads the first `places` lanes; a negative value loads the last |places|
  // lanes. Lanes that are not loaded are zero. Example, places = 3:
  //      | o | o | o | x | x | x | x | x |
  // example, places = -3:
  //      | x | x | x | x | x | o | o | o |
  X_VEC MaskLoad(const ptrdiff_t idx, const IN_T* from,
                 const ptrdiff_t places) {
    return me()->MaskLoadImpl(idx, from, places);
  }

  X_VEC MaskLoadImpl(const ptrdiff_t idx, const IN_T* from,
                     const ptrdiff_t places) {
    return hn::MaskedLoad(PlacesMask(d_in, places), d_in, from + idx);
  }

  // Stores a full vector to `to + idx`. Returning false makes `Unroller`
  // return immediately (default: unaligned store, always returns true).
  bool StoreAndShortCircuit(const ptrdiff_t idx, OUT_T* to, const Y_VEC x) {
    return me()->StoreAndShortCircuitImpl(idx, to, x);
  }

  bool StoreAndShortCircuitImpl(const ptrdiff_t idx, OUT_T* to, const Y_VEC x) {
    hn::StoreU(x, d_out, to + idx);
    return true;
  }

  // Stores only the lanes selected by `places` (same convention as
  // `MaskLoad`) to `to + idx`; other destination elements are left unchanged.
  // Returns |places|.
  ptrdiff_t MaskStore(const ptrdiff_t idx, OUT_T* to, const Y_VEC x,
                      const ptrdiff_t places) {
    return me()->MaskStoreImpl(idx, to, x, places);
  }

  ptrdiff_t MaskStoreImpl(const ptrdiff_t idx, OUT_T* to, const Y_VEC x,
                          const ptrdiff_t places) {
    hn::BlendedStore(x, PlacesMask(d_out, places), d_out, to + idx);
    return std::abs(places);
  }

  // Final reduction of accumulator `x` into `to`. Returns the number of
  // elements written there, which the small-input path of `Unroller` adds to
  // the count it copies out.
  ptrdiff_t Reduce(const Y_VEC x, OUT_T* to) { return me()->ReduceImpl(x, to); }

  ptrdiff_t ReduceImpl(const Y_VEC x, OUT_T* to) {
    // default does nothing
    (void)x;
    (void)to;
    return 0;
  }

  // Merges the three extra accumulators of the 4x-unrolled loop into `*y`.
  // The default discards them.
  void Reduce(const Y_VEC x0, const Y_VEC x1, const Y_VEC x2, Y_VEC* y) {
    me()->ReduceImpl(x0, x1, x2, y);
  }

  void ReduceImpl(const Y_VEC x0, const Y_VEC x1, const Y_VEC x2, Y_VEC* y) {
    // default does nothing
    (void)x0;
    (void)x1;
    (void)x2;
    (void)y;
  }
};

// CRTP base class describing one unit of work for the two-input `Unroller`.
// DERIVED must define `Func(idx, x0, x1, y)` (the base `Func` only forwards to
// `me()->Func`). Every other hook forwards to a `...Impl` method whose default
// is defined here; DERIVED may shadow any of those.
//
// All descriptors are rebound from a descriptor of the widest of
// IN0_T/IN1_T/OUT_T, so all vectors have the same number of lanes.
template <class DERIVED, typename IN0_T, typename IN1_T, typename OUT_T>
struct UnrollerUnit2D {
  DERIVED* me() { return static_cast<DERIVED*>(this); }

  static constexpr size_t kMaxTSize =
      HWY_MAX(sizeof(IN0_T), HWY_MAX(sizeof(IN1_T), sizeof(OUT_T)));
  using LargerT = SignedFromSize<kMaxTSize>;  // only the size matters.

  // Compile-time upper bound on the number of lanes per vector.
  static constexpr size_t MaxUnitLanes() {
    return HWY_MAX_LANES_D(hn::ScalableTag<LargerT>);
  }
  // Number of lanes per vector on the current target; `Unroller` advances by
  // this many elements per vector.
  static size_t ActualLanes() { return Lanes(hn::ScalableTag<LargerT>()); }

  using LargerD = hn::CappedTag<LargerT, MaxUnitLanes()>;

  using I0T = hn::Rebind<IN0_T, LargerD>;
  using I1T = hn::Rebind<IN1_T, LargerD>;
  using OT = hn::Rebind<OUT_T, LargerD>;
  I0T d_in0;
  I1T d_in1;
  OT d_out;
  using Y_VEC = hn::Vec<OT>;
  using X0_VEC = hn::Vec<I0T>;
  using X1_VEC = hn::Vec<I1T>;

  // Computes the output vector for the elements starting at `idx` from the
  // loaded inputs `x0`, `x1` and `y`, the value returned by the previous
  // `Func` call on the same accumulator (initially `YInit()`).
  Y_VEC Func(const ptrdiff_t idx, const X0_VEC x0, const X1_VEC x1,
             const Y_VEC y) {
    return me()->Func(idx, x0, x1, y);
  }

  // Initial value of first-input vectors (default: zero).
  X0_VEC X0Init() { return me()->X0InitImpl(); }

  X0_VEC X0InitImpl() { return hn::Zero(d_in0); }

  // Initial value of second-input vectors (default: zero).
  X1_VEC X1Init() { return me()->X1InitImpl(); }

  X1_VEC X1InitImpl() { return hn::Zero(d_in1); }

  // Initial value of output/accumulator vectors (default: zero).
  Y_VEC YInit() { return me()->YInitImpl(); }

  Y_VEC YInitImpl() { return hn::Zero(d_out); }

  // Loads a full first-input vector from `from + idx` (default: unaligned).
  X0_VEC Load0(const ptrdiff_t idx, const IN0_T* from) {
    return me()->Load0Impl(idx, from);
  }

  X0_VEC Load0Impl(const ptrdiff_t idx, const IN0_T* from) {
    return hn::LoadU(d_in0, from + idx);
  }

  // Loads a full second-input vector from `from + idx` (default: unaligned).
  X1_VEC Load1(const ptrdiff_t idx, const IN1_T* from) {
    return me()->Load1Impl(idx, from);
  }

  X1_VEC Load1Impl(const ptrdiff_t idx, const IN1_T* from) {
    return hn::LoadU(d_in1, from + idx);
  }

  // Returns a mask over `d` selecting the lanes described by `places`: the
  // first `places` lanes if `places >= 0`, otherwise the last |places| lanes
  // of an `ActualLanes()`-wide vector. Only the mask that is needed is built.
  template <class D>
  static hn::Mask<D> PlacesMask(D d, const ptrdiff_t places) {
    if (places < 0) {
      const ptrdiff_t num_skipped =
          places + static_cast<ptrdiff_t>(ActualLanes());
      return hn::Not(hn::FirstN(d, static_cast<size_t>(num_skipped)));
    }
    return hn::FirstN(d, static_cast<size_t>(places));
  }

  // MaskLoad0/MaskLoad1 take either a positive or negative `places`. A
  // positive value loads the first `places` lanes; a negative value loads the
  // last |places| lanes. Lanes that are not loaded are zero.
  // Example, places = 3:
  //      | o | o | o | x | x | x | x | x |
  // example, places = -3:
  //      | x | x | x | x | x | o | o | o |
  X0_VEC MaskLoad0(const ptrdiff_t idx, const IN0_T* from,
                   const ptrdiff_t places) {
    return me()->MaskLoad0Impl(idx, from, places);
  }

  X0_VEC MaskLoad0Impl(const ptrdiff_t idx, const IN0_T* from,
                       const ptrdiff_t places) {
    return hn::MaskedLoad(PlacesMask(d_in0, places), d_in0, from + idx);
  }

  X1_VEC MaskLoad1(const ptrdiff_t idx, const IN1_T* from,
                   const ptrdiff_t places) {
    return me()->MaskLoad1Impl(idx, from, places);
  }

  X1_VEC MaskLoad1Impl(const ptrdiff_t idx, const IN1_T* from,
                       const ptrdiff_t places) {
    return hn::MaskedLoad(PlacesMask(d_in1, places), d_in1, from + idx);
  }

  // Stores a full vector to `to + idx`. Returns false to make `Unroller`
  // return immediately (short-circuit); the default is an unaligned store that
  // always returns true.
  bool StoreAndShortCircuit(const ptrdiff_t idx, OUT_T* to, const Y_VEC x) {
    return me()->StoreAndShortCircuitImpl(idx, to, x);
  }

  bool StoreAndShortCircuitImpl(const ptrdiff_t idx, OUT_T* to, const Y_VEC x) {
    hn::StoreU(x, d_out, to + idx);
    return true;
  }

  // Stores only the lanes selected by `places` (same convention as
  // `MaskLoad0`) to `to + idx`; other destination elements are left
  // unchanged. Returns |places|.
  ptrdiff_t MaskStore(const ptrdiff_t idx, OUT_T* to, const Y_VEC x,
                      const ptrdiff_t places) {
    return me()->MaskStoreImpl(idx, to, x, places);
  }

  ptrdiff_t MaskStoreImpl(const ptrdiff_t idx, OUT_T* to, const Y_VEC x,
                          const ptrdiff_t places) {
    hn::BlendedStore(x, PlacesMask(d_out, places), d_out, to + idx);
    return std::abs(places);
  }

  // Final reduction of accumulator `x` into `to`. Returns the number of
  // elements written there, which the small-input path of `Unroller` adds to
  // the count it copies out.
  ptrdiff_t Reduce(const Y_VEC x, OUT_T* to) { return me()->ReduceImpl(x, to); }

  ptrdiff_t ReduceImpl(const Y_VEC x, OUT_T* to) {
    // default does nothing
    (void)x;
    (void)to;
    return 0;
  }

  // Merges the three extra accumulators of the 4x-unrolled loop into `*y`.
  // The default discards them.
  void Reduce(const Y_VEC x0, const Y_VEC x1, const Y_VEC x2, Y_VEC* y) {
    me()->ReduceImpl(x0, x1, x2, y);
  }

  void ReduceImpl(const Y_VEC x0, const Y_VEC x1, const Y_VEC x2, Y_VEC* y) {
    // default does nothing
    (void)x0;
    (void)x1;
    (void)x2;
    (void)y;
  }
};

// Runs the single-input unit `f` over the `n` elements of `x`, writing to `y`.
//
// The main loop handles four full vectors per iteration with four independent
// accumulators, which are then merged by `f.Reduce(yy3, yy2, yy1, &yy)`.
// Remaining full vectors are handled one at a time. A trailing partial vector
// is handled with `MaskLoad`/`MaskStore` on the last `ActualLanes()` elements,
// selecting only the not-yet-processed ones. If `f.StoreAndShortCircuit`
// returns false, the function returns immediately without reducing;
// otherwise `f.Reduce(yy, y)` is called at the end.
template <class FUNC, typename IN_T, typename OUT_T>
inline void Unroller(FUNC& f, const IN_T* HWY_RESTRICT x, OUT_T* HWY_RESTRICT y,
                     const ptrdiff_t n) {
  const ptrdiff_t actual_lanes =
      static_cast<ptrdiff_t>(RemoveRef<FUNC>::ActualLanes());

  auto xx = f.X0Init();
  auto yy = f.YInit();
  ptrdiff_t i = 0;

#if HWY_MEM_OPS_MIGHT_FAULT
  // Masked loads/stores may touch masked-off lanes on this target, and for
  // n < lanes the masked tail below would start before `x`. Stage such short
  // inputs through stack buffers that hold a full vector instead. The test
  // uses the runtime lane count (like the loops below); the buffers are sized
  // by the compile-time maximum, which is never smaller.
  if (n < actual_lanes) {
    constexpr auto max_lanes =
        static_cast<ptrdiff_t>(RemoveRef<FUNC>::MaxUnitLanes());
    const DFromV<decltype(yy)> d;
    // this may not fit on the stack for HWY_RVV, but we do not reach this code
    // there
    HWY_ALIGN IN_T xtmp[static_cast<size_t>(max_lanes)];
    HWY_ALIGN OUT_T ytmp[static_cast<size_t>(max_lanes)];

    CopyBytes(x, xtmp, static_cast<size_t>(n) * sizeof(IN_T));
    xx = f.MaskLoad(0, xtmp, n);
    yy = f.Func(0, xx, yy);
    Store(Zero(d), d, ytmp);
    // Copy back as many elements as MaskStore and Reduce report writing.
    i += f.MaskStore(0, ytmp, yy, n);
    i += f.Reduce(yy, ytmp);
    CopyBytes(ytmp, y, static_cast<size_t>(i) * sizeof(OUT_T));
    return;
  }
#endif

  if (n > 4 * actual_lanes) {
    auto xx1 = f.X0Init();
    auto yy1 = f.YInit();
    auto xx2 = f.X0Init();
    auto yy2 = f.YInit();
    auto xx3 = f.X0Init();
    auto yy3 = f.YInit();

    while (i + 4 * actual_lanes - 1 < n) {
      xx = f.Load(i, x);
      i += actual_lanes;
      xx1 = f.Load(i, x);
      i += actual_lanes;
      xx2 = f.Load(i, x);
      i += actual_lanes;
      xx3 = f.Load(i, x);
      i -= 3 * actual_lanes;

      yy = f.Func(i, xx, yy);
      yy1 = f.Func(i + actual_lanes, xx1, yy1);
      yy2 = f.Func(i + 2 * actual_lanes, xx2, yy2);
      yy3 = f.Func(i + 3 * actual_lanes, xx3, yy3);

      if (!f.StoreAndShortCircuit(i, y, yy)) return;
      i += actual_lanes;
      if (!f.StoreAndShortCircuit(i, y, yy1)) return;
      i += actual_lanes;
      if (!f.StoreAndShortCircuit(i, y, yy2)) return;
      i += actual_lanes;
      if (!f.StoreAndShortCircuit(i, y, yy3)) return;
      i += actual_lanes;
    }

    // Fold the extra accumulators into `yy` before the single-vector loop.
    f.Reduce(yy3, yy2, yy1, &yy);
  }

  while (i + actual_lanes - 1 < n) {
    xx = f.Load(i, x);
    yy = f.Func(i, xx, yy);
    if (!f.StoreAndShortCircuit(i, y, yy)) return;
    i += actual_lanes;
  }

  if (i != n) {
    // Operate on the vector that ends exactly at `n`; the negative `places`
    // selects only its last n - i lanes, which have not been processed yet.
    xx = f.MaskLoad(n - actual_lanes, x, i - n);
    yy = f.Func(n - actual_lanes, xx, yy);
    f.MaskStore(n - actual_lanes, y, yy, i - n);
  }

  f.Reduce(yy, y);
}

template <class FUNC, typename IN0_T, typename IN1_T, typename OUT_T>
inline void Unroller(FUNC& HWY_RESTRICT f, IN0_T* HWY_RESTRICT x0,
                     IN1_T* HWY_RESTRICT x1, OUT_T* HWY_RESTRICT y,
                     const ptrdiff_t n) {
  const ptrdiff_t lane_sz =
      static_cast<ptrdiff_t>(RemoveRef<FUNC>::ActualLanes());

  auto xx00 = f.X0Init();
  auto xx10 = f.X1Init();
  auto yy = f.YInit();

  ptrdiff_t i = 0;

#if HWY_MEM_OPS_MIGHT_FAULT
  if (n < lane_sz) {
    const DFromV<decltype(yy)> d;
    // this may not fit on the stack for HWY_RVV, but we do not reach this code
    // there
    constexpr auto max_lane_sz =
        static_cast<ptrdiff_t>(RemoveRef<FUNC>::MaxUnitLanes());
    HWY_ALIGN IN0_T xtmp0[static_cast<size_t>(max_lane_sz)];
    HWY_ALIGN IN1_T xtmp1[static_cast<size_t>(max_lane_sz)];
    HWY_ALIGN OUT_T ytmp[static_cast<size_t>(max_lane_sz)];

    CopyBytes(x0, xtmp0, static_cast<size_t>(n) * sizeof(IN0_T));
    CopyBytes(x1, xtmp1, static_cast<size_t>(n) * sizeof(IN1_T));
    xx00 = f.MaskLoad0(0, xtmp0, n);
    xx10 = f.MaskLoad1(0, xtmp1, n);
    yy = f.Func(0, xx00, xx10, yy);
    Store(Zero(d), d, ytmp);
    i += f.MaskStore(0, ytmp, yy, n);
    i += f.Reduce(yy, ytmp);
    CopyBytes(ytmp, y, static_cast<size_t>(i) * sizeof(OUT_T));
    return;
  }
#endif

  if (n > 4 * lane_sz) {
    auto xx01 = f.X0Init();
    auto xx11 = f.X1Init();
    auto yy1 = f.YInit();
    auto xx02 = f.X0Init();
    auto xx12 = f.X1Init();
    auto yy2 = f.YInit();
    auto xx03 = f.X0Init();
    auto xx13 = f.X1Init();
    auto yy3 = f.YInit();

    while (i + 4 * lane_sz - 1 < n) {
      xx00 = f.Load0(i, x0);
      xx10 = f.Load1(i, x1);
      i += lane_sz;
      xx01 = f.Load0(i, x0);
      xx11 = f.Load1(i, x1);
      i += lane_sz;
      xx02 = f.Load0(i, x0);
      xx12 = f.Load1(i, x1);
      i += lane_sz;
      xx03 = f.Load0(i, x0);
      xx13 = f.Load1(i, x1);
      i -= 3 * lane_sz;

      yy = f.Func(i, xx00, xx10, yy);
      yy1 = f.Func(i + lane_sz, xx01, xx11, yy1);
      yy2 = f.Func(i + 2 * lane_sz, xx02, xx12, yy2);
      yy3 = f.Func(i + 3 * lane_sz, xx03, xx13, yy3);

      if (!f.StoreAndShortCircuit(i, y, yy)) return;
      i += lane_sz;
      if (!f.StoreAndShortCircuit(i, y, yy1)) return;
      i += lane_sz;
      if (!f.StoreAndShortCircuit(i, y, yy2)) return;
      i += lane_sz;
      if (!f.StoreAndShortCircuit(i, y, yy3)) return;
      i += lane_sz;
    }

    f.Reduce(yy3, yy2, yy1, &yy);
  }

  while (i + lane_sz - 1 < n) {
    xx00 = f.Load0(i, x0);
    xx10 = f.Load1(i, x1);
    yy = f.Func(i, xx00, xx10, yy);
    if (!f.StoreAndShortCircuit(i, y, yy)) return;
    i += lane_sz;
  }

  if (i != n) {
    xx00 = f.MaskLoad0(n - lane_sz, x0, i - n);
    xx10 = f.MaskLoad1(n - lane_sz, x1, i - n);
    yy = f.Func(n - lane_sz, xx00, xx10, yy);
    f.MaskStore(n - lane_sz, y, yy, i - n);
  }

  f.Reduce(yy, y);
}

}  // namespace HWY_NAMESPACE
}  // namespace hwy
HWY_AFTER_NAMESPACE();

#endif  // HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_
