#pragma once #include "gsl-lite.hpp" #include #include #include #include #include #include "common/define.h" #include "tensor.h" namespace dynasty { namespace common { template void checkIndex(size_t idx, const std::vector& vec){ if (idx >= vec.size()) { std::cout << "idx: " << idx << " not fit with vector size: " << vec.size() ; exit(-1); } } enum class SimpleBroadcast : int32_t { NoBroadcast = (int32_t)-1, LeftScalar = (int32_t)-2, RightScalar = (int32_t)-3, RightPerChannelBatch1 = (int32_t)-4, RightPerChannelBatchN = (int32_t)-5, }; //-------------------------------------------------------------------- // Tensor Pitch //-------------------------------------------------------------------- struct TensorPitches : std::vector { TensorPitches(const TensorShape& shape, size_t rank = 0) : TensorPitches(shape.GetDims(), rank) {} TensorPitches(const std::vector& dims, size_t rank = 0) : std::vector(std::max(rank, dims.size()), 0) { Calculate(gsl::span(data(), size()), dims); } static bool Calculate(gsl::span p, const std::vector& dims) { // The pitches is the size of the next inner axis. Aka the amount to move by one of the next inner axis. // For a tensor with shape(2,3,4,5) the values would be: (3*4*5, 4*5, 5, 1) // Note that the outermost '2' is never used, as you never need to move by the entire size of the outermost axis auto tensor_rank = dims.size(); auto pitch_rank = p.size(); auto padded_rank = pitch_rank - tensor_rank; if (gsl::narrow_cast(padded_rank) < 0) return false; // Guard against Scalars if (pitch_rank == 0) { return true; } *(p.rbegin()) = 1; // The innermost axis is 1 (single values) if (tensor_rank > 1) { for (size_t i = tensor_rank - 1; i-- > 0;) { p.operator[](i + padded_rank) = p.operator[](i + 1 + padded_rank) * dims[i + 1]; } } if (padded_rank >= 1) { for (size_t i = 0; i < padded_rank; ++i) { if (i == 0 && tensor_rank > 0) // For scalar tensor, the values in the pitches are all 1. p.operator[](padded_rank - 1) = p.operator[](padded_rank) * dims[0]; else p.operator[](padded_rank - 1 - i) = p.operator[](padded_rank - 1); } } return true; } }; } // namespace common } //