2026-01-28 06:16:04 +00:00

88 lines
2.6 KiB
C++

#pragma once
#include "gsl-lite.hpp"
#include <vector>
#include <algorithm>
#include <string>
#include <cstring>
#include <iostream>
#include "common/define.h"
#include "tensor.h"
namespace dynasty {
namespace common {
// Bounds-check helper: verifies that `idx` is a valid index into `vec`.
// On failure it prints a diagnostic and terminates the process with a
// non-zero exit status; on success it is a no-op.
//
// @param idx  candidate index
// @param vec  vector being indexed
template<typename T>
void checkIndex(size_t idx, const std::vector<T>& vec) {
  if (idx >= vec.size()) {
    // Diagnostics belong on stderr (stdout may be redirected/buffered),
    // and the message needs a terminating newline.
    std::cerr << "idx: " << idx << " not fit with vector size: " << vec.size() << '\n';
    exit(-1);
  }
}
// Sentinel codes classifying how two operands of an element-wise binary op
// broadcast against each other. All values are negative, presumably so they
// cannot collide with non-negative per-axis counts/indices stored in the
// same int32_t — NOTE(review): confirm against callers.
// The underlying type is already int32_t, so no casts are needed on the
// enumerator values (the original used redundant C-style casts).
enum class SimpleBroadcast : int32_t {
  NoBroadcast = -1,            // shapes match; no broadcasting required
  LeftScalar = -2,             // left operand is a single scalar value
  RightScalar = -3,            // right operand is a single scalar value
  RightPerChannelBatch1 = -4,  // right broadcasts per channel, batch == 1 (assumed from name)
  RightPerChannelBatchN = -5,  // right broadcasts per channel, batch > 1 (assumed from name)
};
//--------------------------------------------------------------------
// Tensor Pitch
//--------------------------------------------------------------------
// Pitch (stride) table for a tensor shape: element i holds the number of
// elements to skip to advance by one along axis i. Derives from
// std::vector<int32_t> so it can be used directly as a stride array.
struct TensorPitches : std::vector<int32_t> {
  // Builds the pitch table for `shape`, optionally left-padded to `rank`
  // entries (padding axes behave as size-1 axes).
  TensorPitches(const TensorShape& shape, size_t rank = 0) : TensorPitches(shape.GetDims(), rank) {}
  TensorPitches(const std::vector<int32_t>& dims, size_t rank = 0)
      : std::vector<int32_t>(std::max(rank, dims.size()), 0) {
    Calculate(gsl::span<int32_t>(data(), size()), dims);
  }

  // Fills `p` with the pitches for `dims`.
  // Returns false when `p` is too small to hold one pitch per dimension.
  static bool Calculate(gsl::span<int32_t> p, const std::vector<int32_t>& dims) {
    // The pitches is the size of the next inner axis. Aka the amount to move by one of the next inner axis.
    // For a tensor with shape(2,3,4,5) the values would be: (3*4*5, 4*5, 5, 1)
    // Note that the outermost '2' is never used, as you never need to move by the entire size of the outermost axis
    const auto tensor_rank = dims.size();
    const auto pitch_rank = p.size();
    // Reject an output span smaller than the tensor rank up front. (The
    // original computed pitch_rank - tensor_rank in unsigned arithmetic and
    // relied on an implementation-defined narrowing cast of the wrapped
    // value going negative; this comparison is well-defined.)
    if (tensor_rank > pitch_rank)
      return false;
    const auto padded_rank = pitch_rank - tensor_rank;
    // Guard against scalars: nothing to fill for a zero-length span.
    if (pitch_rank == 0)
      return true;
    p[pitch_rank - 1] = 1;  // innermost axis always moves by single values
    if (tensor_rank > 1) {
      // Walk outward: each pitch is the next-inner pitch times that axis's size.
      for (size_t i = tensor_rank - 1; i-- > 0;)
        p[i + padded_rank] = p[i + 1 + padded_rank] * dims[i + 1];
    }
    // Fill the left-padding entries. Padding axes act like size-1 axes, so
    // after the first padded entry (the whole tensor's element count) every
    // further entry repeats it. For a scalar tensor all entries become 1.
    for (size_t i = 0; i < padded_rank; ++i) {
      if (i == 0 && tensor_rank > 0)
        p[padded_rank - 1] = p[padded_rank] * dims[0];
      else
        p[padded_rank - 1 - i] = p[padded_rank - 1];
    }
    return true;
  }
};
} // namespace common
} // namespace dynasty