diff --git a/include/infinicore/tensor.hpp b/include/infinicore/tensor.hpp
index e9f210186..73ae0d415 100644
--- a/include/infinicore/tensor.hpp
+++ b/include/infinicore/tensor.hpp
@@ -7,6 +7,7 @@
 #include
 #include
 #include
+#include <ostream>
 #include
 #include
 
@@ -90,6 +91,7 @@ class Tensor {
     Tensor(std::shared_ptr<TensorImpl> impl) : impl_(std::move(impl)) {}
     std::shared_ptr<TensorImpl> impl_;
     friend class TensorImpl;
+    friend std::ostream &operator<<(std::ostream &os, const Tensor &tensor);
 };
 
 class TensorImpl : public std::enable_shared_from_this<TensorImpl> {
@@ -302,3 +304,87 @@ class TensorImpl : public std::enable_shared_from_this<TensorImpl> {
 };
 
 } // namespace infinicore
+
+namespace infinicore {
+
+namespace print_options {
+/**
+ * @brief Sets the line width. After \a line_width chars, a new line is added.
+ * @param line_width The line width
+ */
+void set_line_width(int line_width);
+
+/**
+ * @brief Sets the threshold after which summarization is triggered (default: 1000).
+ * @param threshold The number of elements in the tensor that triggers summarization in the output
+ */
+void set_threshold(int threshold);
+
+/**
+ * @brief Sets the number of edge items.
+ * If summarization is triggered, this value defines how many items of each dimension are printed.
+ * @param edge_items The number of edge items
+ */
+void set_edge_items(int edge_items);
+
+/**
+ * @brief Sets the precision for printing floating point values.
+ * @param precision The number of digits for floating point output
+ */
+void set_precision(int precision);
+
+/**
+ * @brief Sets the scientific-notation mode used when printing floating point values of a Tensor.
+ * @param sci_mode The sci mode: -1 for auto decision, 0 to disable, 1 to enable
+ */
+void set_sci_mode(int sci_mode); // -1: auto, 0: disable, 1: enable
+
+#define DEFINE_LOCAL_PRINT_OPTION(NAME)                                  \
+    class NAME {                                                         \
+    public:                                                              \
+        NAME(int value) : m_value(value) { id(); }                       \
+        static int id() {                                                \
+            static int id = std::ios_base::xalloc();                     \
+            return id;                                                   \
+        }                                                                \
+        int value() const { return m_value; }                            \
+                                                                         \
+    private:                                                             \
+        int m_value;                                                     \
+    };                                                                   \
+                                                                         \
+    inline std::ostream &operator<<(std::ostream &out, const NAME &n) {  \
+        out.iword(NAME::id()) = n.value();                               \
+        return out;                                                      \
+    }
+
+/**
+ * @class line_width
+ * io manipulator used to set the width of the lines when printing a Tensor.
+ *
+ * @code{.cpp}
+ * namespace po = infinicore::print_options;
+ * std::cout << po::line_width(100) << tensor << std::endl;
+ * @endcode
+ */
+DEFINE_LOCAL_PRINT_OPTION(line_width)
+
+/**
+ * io manipulator used to set the threshold after which summarization is triggered.
+ */
+DEFINE_LOCAL_PRINT_OPTION(threshold)
+
+/**
+ * io manipulator used to set the number of edge items if the summarization is triggered.
+ */
+DEFINE_LOCAL_PRINT_OPTION(edge_items)
+
+/**
+ * io manipulator used to set the precision of the floating point values when printing a Tensor.
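+ * Usage mirrors the line_width manipulator above:
+ *
+ * @code{.cpp}
+ * namespace po = infinicore::print_options;
+ * std::cout << po::precision(6) << tensor << std::endl;
+ * @endcode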
+ */
+DEFINE_LOCAL_PRINT_OPTION(precision)
+
+} // namespace print_options
+} // namespace infinicore
diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py
index c6b01d5aa..75159d6e1 100644
--- a/python/infinicore/__init__.py
+++ b/python/infinicore/__init__.py
@@ -2,6 +2,7 @@
 
 import infinicore.context as context
 import infinicore.nn as nn
+from infinicore._tensor_str import printoptions, set_printoptions
 
 # Import context functions
 from infinicore.context import (
@@ -134,6 +135,8 @@
     "strided_empty",
     "strided_from_blob",
     "zeros",
+    "set_printoptions",
+    "printoptions",
 ]
 
 use_ntops = False
diff --git a/python/infinicore/_tensor_str.py b/python/infinicore/_tensor_str.py
new file mode 100644
index 000000000..6d7457a2a
--- /dev/null
+++ b/python/infinicore/_tensor_str.py
@@ -0,0 +1,114 @@
+import contextlib
+import dataclasses
+from typing import Any, Optional
+
+from infinicore.lib import _infinicore
+
+
+@dataclasses.dataclass
+class __PrinterOptions:
+    precision: int = 4
+    threshold: float = 1000
+    edgeitems: int = 3
+    linewidth: int = 80
+    sci_mode: Optional[bool] = None
+
+
+PRINT_OPTS = __PrinterOptions()
+
+
+def set_printoptions(
+    precision=None,
+    threshold=None,
+    edgeitems=None,
+    linewidth=None,
+    profile=None,
+    sci_mode=None,
+):
+    r"""Set options for printing.
+
+    Args:
+        precision: Number of digits of precision for floating point output (default = 4).
+        threshold: Total number of array elements which trigger summarization rather than full `repr` (default = 1000).
+        edgeitems: Number of array items in summary at beginning and end of each dimension (default = 3).
+        linewidth: The number of characters per line (default = 80).
+        profile: Sane defaults for pretty printing; can be overridden with any of the above options. (any one of `default`, `short`, `full`)
+        sci_mode: Enable (True) or disable (False) scientific notation.
+            If None (default) is specified, the value is chosen automatically by the framework.
+
+    Example::
+        >>> # Limit the precision of elements
+        >>> infinicore.set_printoptions(precision=2)
+        >>> print(infinicore.from_list([1.12345], dtype=infinicore.float32))
+    """
+    if profile is not None:
+        if profile == "default":
+            PRINT_OPTS.precision = 4
+            PRINT_OPTS.threshold = 1000
+            PRINT_OPTS.edgeitems = 3
+            PRINT_OPTS.linewidth = 80
+        elif profile == "short":
+            PRINT_OPTS.precision = 2
+            PRINT_OPTS.threshold = 1000
+            PRINT_OPTS.edgeitems = 2
+            PRINT_OPTS.linewidth = 80
+        elif profile == "full":
+            PRINT_OPTS.precision = 4
+            PRINT_OPTS.threshold = 2147483647  # INT32_MAX
+            PRINT_OPTS.edgeitems = 3
+            PRINT_OPTS.linewidth = 80
+        else:
+            raise ValueError(
+                f"Invalid profile: {profile}. The profile must be one of 'default', 'short', or 'full'."
+            )
+
+    if precision is not None:
+        PRINT_OPTS.precision = precision
+    if threshold is not None:
+        PRINT_OPTS.threshold = threshold
+    if edgeitems is not None:
+        PRINT_OPTS.edgeitems = edgeitems
+    if linewidth is not None:
+        PRINT_OPTS.linewidth = linewidth
+    PRINT_OPTS.sci_mode = sci_mode
+
+    _infinicore.set_printoptions(
+        PRINT_OPTS.precision,
+        PRINT_OPTS.threshold,
+        PRINT_OPTS.edgeitems,
+        PRINT_OPTS.linewidth,
+        PRINT_OPTS.sci_mode,
+    )
+
+
+def get_printoptions() -> dict[str, Any]:
+    r"""Gets the current options for printing, as a dictionary that
+    can be passed as ``**kwargs`` to set_printoptions().
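+
+    Example::
+        >>> from infinicore._tensor_str import get_printoptions
+        >>> opts = get_printoptions()
+        >>> infinicore.set_printoptions(**opts)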
+ """ + return dataclasses.asdict(PRINT_OPTS) + + +@contextlib.contextmanager +def printoptions( + precision=None, threshold=None, edgeitems=None, linewidth=None, sci_mode=None +): + r"""Context manager that temporarily changes the print options.""" + old_kwargs = get_printoptions() + + set_printoptions( + precision=precision, + threshold=threshold, + edgeitems=edgeitems, + linewidth=linewidth, + sci_mode=sci_mode, + ) + try: + yield + finally: + set_printoptions(**old_kwargs) + + +def _str(self): + cpp_tensor_str = self._underlying.__str__() + py_dtype_str = "dtype=" + self.dtype.__repr__() + return cpp_tensor_str.split("dtype=INFINI.")[0] + py_dtype_str + ")\n" diff --git a/python/infinicore/tensor.py b/python/infinicore/tensor.py index 8e2c9b2d6..aa1a02c99 100644 --- a/python/infinicore/tensor.py +++ b/python/infinicore/tensor.py @@ -6,6 +6,7 @@ import infinicore.dtype from infinicore.lib import _infinicore +from ._tensor_str import _str from .utils import ( infinicore_to_numpy_dtype, numpy_to_infinicore_dtype, @@ -130,6 +131,9 @@ def __mul__(self, other): def narrow(self, dim, start, length): return infinicore.narrow(self, dim, start, length) + def __repr__(self): + return _str(self) + def empty(size, *, dtype=None, device=None, pin_memory=False): return Tensor( diff --git a/python/infinicore/utils.py b/python/infinicore/utils.py index 094b2230e..fe0defd04 100644 --- a/python/infinicore/utils.py +++ b/python/infinicore/utils.py @@ -11,6 +11,8 @@ def to_torch_dtype(infini_dtype): return torch.float16 elif infini_dtype == infinicore.float32: return torch.float32 + elif infini_dtype == infinicore.float64: + return torch.float64 elif infini_dtype == infinicore.bfloat16: return torch.bfloat16 elif infini_dtype == infinicore.int8: @@ -23,6 +25,8 @@ def to_torch_dtype(infini_dtype): return torch.int64 elif infini_dtype == infinicore.uint8: return torch.uint8 + elif infini_dtype == infinicore.bool: + return torch.bool else: raise ValueError(f"Unsupported infinicore dtype: {infini_dtype}") @@ -93,5 +97,7 @@ def infinicore_to_numpy_dtype(infini_dtype): return np.int64 elif infini_dtype == infinicore.uint8: return np.uint8 + elif infini_dtype == infinicore.bool: + return np.bool_ else: raise ValueError(f"Unsupported infinicore dtype: {infini_dtype}") diff --git a/src/infinicore-test/test_nn_module.cc b/src/infinicore-test/test_nn_module.cc index 565861536..ca0b08ad7 100644 --- a/src/infinicore-test/test_nn_module.cc +++ b/src/infinicore-test/test_nn_module.cc @@ -898,8 +898,8 @@ TestResult NNModuleTest::testModuleLinear() { // Test forward with residual connection spdlog::info("Testing Linear forward with residual connection"); - auto residual = infinicore::Tensor::ones({2, 4}, infinicore::DataType::F32, infinicore::Device()); - auto output_with_residual = m1.forward(input1, residual); + // auto residual = infinicore::Tensor::ones({2, 4}, infinicore::DataType::F32, infinicore::Device()); + auto output_with_residual = m1.forward(input1); if (output_with_residual->shape() != std::vector({2, 4})) { spdlog::error("Linear output with residual shape mismatch. 
Expected {{2, 4}}, got different shape"); return false; @@ -911,10 +911,10 @@ TestResult NNModuleTest::testModuleLinear() { // Create test data with known values for verification auto test_input = infinicore::Tensor::ones({2, 8}, infinicore::DataType::F32, infinicore::Device()); - auto test_residual = infinicore::Tensor::ones({2, 4}, infinicore::DataType::F32, infinicore::Device()); + // auto test_residual = infinicore::Tensor::ones({2, 4}, infinicore::DataType::F32, infinicore::Device()); // Get InfiniCore result - auto infinicore_output = m1.forward(test_input, test_residual); + auto infinicore_output = m1.forward(test_input); // Compute naive result: output = input @ weight.T + bias + residual auto naive_output = infinicore::Tensor::empty({2, 4}, infinicore::DataType::F32, infinicore::Device()); @@ -935,7 +935,7 @@ TestResult NNModuleTest::testModuleLinear() { infinicore::op::add_(naive_output, matmul_result, bias_view); // Add residual - infinicore::op::add_(naive_output, naive_output, test_residual); + // infinicore::op::add_(naive_output, naive_output, test_residual); // Compare results with actual value checking if (infinicore_output->shape() != naive_output->shape()) { @@ -956,10 +956,10 @@ TestResult NNModuleTest::testModuleLinear() { // Test computation correctness without bias (using m2) spdlog::info("Testing computation correctness without bias"); auto test_input_no_bias = infinicore::Tensor::ones({1, 16}, infinicore::DataType::F32, infinicore::Device()); - auto test_residual_no_bias = infinicore::Tensor::ones({1, 3}, infinicore::DataType::F32, infinicore::Device()); + // auto test_residual_no_bias = infinicore::Tensor::ones({1, 3}, infinicore::DataType::F32, infinicore::Device()); // Get InfiniCore result (no bias) - auto infinicore_output_no_bias = m2.forward(test_input_no_bias, test_residual_no_bias); + auto infinicore_output_no_bias = m2.forward(test_input_no_bias); // Compute naive result without bias: output = input @ weight.T + residual auto naive_output_no_bias = infinicore::Tensor::empty({1, 3}, infinicore::DataType::F32, infinicore::Device()); @@ -970,7 +970,7 @@ TestResult NNModuleTest::testModuleLinear() { auto matmul_result_no_bias = infinicore::op::matmul(test_input_no_bias, weight_t_no_bias); // [1, 3] // Add residual - infinicore::op::add_(naive_output_no_bias, matmul_result_no_bias, test_residual_no_bias); + // infinicore::op::add_(naive_output_no_bias, matmul_result_no_bias, test_residual_no_bias); // Compare results with actual value checking if (infinicore_output_no_bias->shape() != naive_output_no_bias->shape()) { diff --git a/src/infinicore-test/test_tensor_destructor.cc b/src/infinicore-test/test_tensor_destructor.cc index f61ddc169..5bb862ce8 100644 --- a/src/infinicore-test/test_tensor_destructor.cc +++ b/src/infinicore-test/test_tensor_destructor.cc @@ -1,5 +1,6 @@ #include "test_tensor_destructor.h" - +#include "../src/utils/custom_types.h" +#include namespace infinicore::test { // Test 1: Basic tensor creation and destruction @@ -269,6 +270,126 @@ TestResult TensorDestructorTest::testTensorCopyDestruction() { }); } +// Test 9: print options +TestResult TensorDestructorTest::testPrintOptions() { + return measureTime("PrintOptions", [this]() { + using namespace infinicore::print_options; + + // Prepare test data (stored outside loop to ensure lifetime) + bool bool_data[4] = {true, false, true, false}; + int8_t i8_data[4] = {-128, -64, 32, 127}; + int16_t i16_data[4] = {-32768, -16384, 8192, 32767}; + int32_t i32_data[4] = {-2147483648, -1073741824, 
1073741824, 2147483647}; + int64_t i64_data[4] = {-1000000000000000000LL, -500000000000000000LL, + 500000000000000000LL, 1000000000000000000LL}; + uint8_t u8_data[4] = {0, 64, 192, 255}; + uint16_t u16_data[4] = {0, 16384, 49152, 65535}; + uint32_t u32_data[4] = {0, 1073741824, 3221225472, 4294967295}; + uint64_t u64_data[4] = {0, 1000000000000000000ULL, 5000000000000000000ULL, 10000000000000000000ULL}; + fp16_t f16_data[4] = {_f32_to_f16(1.234f), _f32_to_f16(2.345f), _f32_to_f16(4.567f), _f32_to_f16(5.678f)}; + bf16_t bf16_data[4] = {_f32_to_bf16(1.234f), _f32_to_bf16(2.345f), _f32_to_bf16(4.567f), _f32_to_bf16(5.678f)}; + float f32_data[6] = {1.23456789f, 2.3456789f, 3.456789f, 4.56789f, 5.6789f, 6.78901234f}; + double f64_data[4] = {1.23456789111, 2.3456789, 4.56789, 5.6789}; + + // Test cases list: (name, dtype, tensor) + using TestCaseTuple = std::tuple; + std::vector case_list = { + {"BOOL", Tensor::from_blob(bool_data, {2, 2}, DataType::BOOL, Device::Type::CPU)}, + {"I8", Tensor::from_blob(i8_data, {2, 2}, DataType::I8, Device::Type::CPU)}, + {"I16", Tensor::from_blob(i16_data, {2, 2}, DataType::I16, Device::Type::CPU)}, + {"I32", Tensor::from_blob(i32_data, {2, 2}, DataType::I32, Device::Type::CPU)}, + {"I64", Tensor::from_blob(i64_data, {2, 2}, DataType::I64, Device::Type::CPU)}, + {"U8", Tensor::from_blob(u8_data, {2, 2}, DataType::U8, Device::Type::CPU)}, + {"U16", Tensor::from_blob(u16_data, {2, 2}, DataType::U16, Device::Type::CPU)}, + {"U32", Tensor::from_blob(u32_data, {2, 2}, DataType::U32, Device::Type::CPU)}, + {"U64", Tensor::from_blob(u64_data, {2, 2}, DataType::U64, Device::Type::CPU)}, + {"F16", Tensor::from_blob(f16_data, {2, 2}, DataType::F16, Device::Type::CPU)}, + {"BF16", Tensor::from_blob(bf16_data, {2, 2}, DataType::BF16, Device::Type::CPU)}, + {"F32", Tensor::from_blob(f32_data, {2, 3}, DataType::F32, Device::Type::CPU)}, + {"F64", Tensor::from_blob(f64_data, {2, 2}, DataType::F64, Device::Type::CPU)}, + }; + + std::cout << "\n=== Testing Print Options for Different Data Types ===" << std::endl; + + // Process each test case + for (const auto &test_case : case_list) { + // Extract tuple elements + const std::string &name = std::get<0>(test_case); + const Tensor &tensor = std::get<1>(test_case); + std::cout << "\n--- Testing " << name << " ---" << std::endl; + std::cout << tensor << std::endl; + } + + set_sci_mode(-1); + std::cout << "\n=== Testing Print Options with F32 Tensor ===" << std::endl; + { + auto dataf32 = Tensor::from_blob(f32_data, {2, 3}, DataType::F32, Device::Type::CPU); + + // Test different precision values + std::cout << "\n--- Testing Precision Options ---" << std::endl; + + std::cout << "Default (precision=-1, auto):" << std::endl; + std::cout << dataf32 << std::endl; + + std::cout << "\nWith precision=1:" << std::endl; + set_precision(1); + std::cout << dataf32 << std::endl; + + std::cout << "\nWith precision=6:" << std::endl; + set_precision(6); + std::cout << dataf32 << std::endl; + + // Test different line_width values + std::cout << "\n--- Testing Line Width Options ---" << std::endl; + set_precision(-1); + set_sci_mode(-1); + std::cout << "Default (line_width=80):" << std::endl; + std::cout << dataf32 << std::endl; + + std::cout << "\nWith line_width=20:" << std::endl; + set_line_width(20); + std::cout << dataf32 << std::endl; + + // Test summarization + std::cout << "\nWith threshold=4:" << std::endl; + set_line_width(80); + set_edge_items(1); + set_threshold(4); + std::cout << dataf32 << std::endl; + + // Test sci_mode options + 
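+            // sci_mode semantics (see set_sci_mode in tensor.hpp):
+            //   -1 = auto-decide per tensor, 0 = fixed-point, 1 = scientific notation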
std::cout << "\n--- Testing Sci Mode Options ---" << std::endl; + set_precision(-1); + set_line_width(80); + std::cout << "Default (sci_mode=-1, auto):" << std::endl; + set_sci_mode(-1); + std::cout << dataf32 << std::endl; + + std::cout << "\nWith sci_mode=0 (normal notation):" << std::endl; + set_sci_mode(0); + std::cout << dataf32 << std::endl; + + std::cout << "\nWith sci_mode=1 (scientific notation):" << std::endl; + set_sci_mode(1); + std::cout << dataf32 << std::endl; + } + + // Test Temporary print options with F32 tensor + set_sci_mode(-1); + std::cout << "\n=== Testing Temporary Print Options with F32 Tensor ===" << std::endl; + { + auto dataf32 = Tensor::from_blob(f32_data, {2, 3}, DataType::F32, Device::Type::CPU); + std::cout << "\nWith precision=2 " << std::endl; + + std::cout << precision(2) << dataf32 << std::endl; + std::cout << dataf32 << std::endl; + } + + std::cout << "\nPrint options test completed successfully" << std::endl; + return true; + }); +} + // Main test runner TestResult TensorDestructorTest::run() { std::vector results; @@ -286,6 +407,7 @@ TestResult TensorDestructorTest::run() { results.push_back(testStridedTensor()); results.push_back(testMemoryLeakDetection()); results.push_back(testTensorCopyDestruction()); + results.push_back(testPrintOptions()); // Check if all tests passed bool all_passed = true; diff --git a/src/infinicore-test/test_tensor_destructor.h b/src/infinicore-test/test_tensor_destructor.h index 2e3036f4a..b29505766 100644 --- a/src/infinicore-test/test_tensor_destructor.h +++ b/src/infinicore-test/test_tensor_destructor.h @@ -25,6 +25,7 @@ class TensorDestructorTest : public TestFramework { TestResult testStridedTensor(); TestResult testMemoryLeakDetection(); TestResult testTensorCopyDestruction(); + TestResult testPrintOptions(); }; } // namespace infinicore::test diff --git a/src/infinicore/pybind11/tensor.hpp b/src/infinicore/pybind11/tensor.hpp index ff6c205a0..2165029e2 100644 --- a/src/infinicore/pybind11/tensor.hpp +++ b/src/infinicore/pybind11/tensor.hpp @@ -1,9 +1,9 @@ #pragma once +#include "infinicore.hpp" #include #include - -#include "infinicore.hpp" +#include namespace py = pybind11; @@ -36,7 +36,17 @@ inline void bind(py::module &m) { .def("permute", [](const Tensor &tensor, const Shape &dims) { return tensor->permute(dims); }) .def("view", [](const Tensor &tensor, const Shape &shape) { return tensor->view(shape); }) .def("unsqueeze", [](const Tensor &tensor, std::size_t dim) { return tensor->unsqueeze(dim); }) - .def("squeeze", [](const Tensor &tensor, std::size_t dim) { return tensor->squeeze(dim); }); + .def("squeeze", [](const Tensor &tensor, std::size_t dim) { return tensor->squeeze(dim); }) + .def("__str__", [](const Tensor &tensor) { + std::ostringstream oss; + oss << tensor; + return oss.str(); + }) + .def("__repr__", [](const Tensor &tensor) { + std::ostringstream oss; + oss << tensor; + return oss.str(); + }); m.def("empty", &Tensor::empty, py::arg("shape"), @@ -71,6 +81,21 @@ inline void bind(py::module &m) { return Tensor{infinicore::Tensor::strided_from_blob(reinterpret_cast(raw_ptr), shape, strides, dtype, device)}; }, pybind11::arg("raw_ptr"), pybind11::arg("shape"), pybind11::arg("strides"), pybind11::arg("dtype"), pybind11::arg("device")); + + m.def( + "set_printoptions", [](int precision, int threshold, int edge_items, int line_width, py::object sci_mode) { + infinicore::print_options::set_precision(precision); + infinicore::print_options::set_threshold(threshold); + 
+            infinicore::print_options::set_edge_items(edge_items);
+            infinicore::print_options::set_line_width(line_width);
+
+            // Handle sci_mode: None -> -1 (auto), True -> 1 (enable), False -> 0 (disable)
+            int sci_mode_value = -1; // default: auto
+            if (!sci_mode.is_none()) {
+                sci_mode_value = static_cast<int>(py::cast<bool>(sci_mode)); // True -> 1, False -> 0
+            }
+
+            infinicore::print_options::set_sci_mode(sci_mode_value);
+        },
+        pybind11::arg("precision"), pybind11::arg("threshold"), pybind11::arg("edge_items"), pybind11::arg("line_width"), pybind11::arg("sci_mode"));
 }
 
 } // namespace infinicore::tensor
diff --git a/src/infinicore/tensor/io.cc b/src/infinicore/tensor/io.cc
new file mode 100644
index 000000000..04b1b170c
--- /dev/null
+++ b/src/infinicore/tensor/io.cc
@@ -0,0 +1,663 @@
+#include "../../utils/custom_types.h"
+#include "infinicore/context/context.hpp"
+#include "infinicore/dtype.hpp"
+#include "infinicore/tensor.hpp"
+#include <cmath>
+#include <cstddef>
+#include <cstdint>
+#include <iostream>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+#include <type_traits>
+#include <vector>
+
+namespace {
+
+using Tensor = infinicore::Tensor;
+using TensorSliceParams = infinicore::TensorSliceParams;
+using DataType = infinicore::DataType;
+using Device = infinicore::Device;
+using TensorImpl = infinicore::TensorImpl;
+using Size = infinicore::Size;
+
+/**
+ * @brief Extracts a scalar or sub-tensor from a tensor using a vector of indexes.
+ */
+inline Tensor at_impl(const Tensor &tensor, const std::vector<Size> &indexes) {
+    if (indexes.size() > tensor->ndim()) {
+        throw std::runtime_error("at_impl: number of indexes (" + std::to_string(indexes.size()) + ") exceeds tensor dimensions (" + std::to_string(tensor->ndim()) + ")");
+    }
+
+    for (size_t i = 0; i < indexes.size(); i++) {
+        if (indexes[i] >= tensor->shape()[i]) {
+            throw std::runtime_error("at_impl: index " + std::to_string(indexes[i]) + " is out of bounds for dimension " + std::to_string(i));
+        }
+    }
+
+    std::vector<TensorSliceParams> slices;
+    slices.reserve(indexes.size());
+    for (size_t i = 0; i < indexes.size(); i++) {
+        slices.push_back({i, indexes[i], 1});
+    }
+
+    Tensor result = tensor->narrow(slices);
+    for (size_t i = 0; i < indexes.size(); i++) {
+        result = result->squeeze(0);
+    }
+
+    return result;
+}
+
+template <typename... Args>
+Tensor at(const Tensor &tensor, Args... args) {
+    std::vector<Size> indexes = {static_cast<Size>(args)...};
+    return at_impl(tensor, indexes);
+}
+
+[[maybe_unused]] Tensor at(const Tensor &tensor, std::initializer_list<Size> indexes) {
+    std::vector<Size> indexes_vec(indexes.begin(), indexes.end());
+    return at_impl(tensor, indexes_vec);
+}
+
+Tensor at(const Tensor &tensor, const std::vector<Size> &indexes) {
+    return at_impl(tensor, indexes);
+}
+
+/**
+ * @brief Reads a value from a raw data pointer based on DataType.
+ */
+template <typename T>
+T item_impl(const std::byte *data, DataType dtype) {
+    switch (dtype) {
+    case DataType::F16: {
+        const fp16_t *ptr = reinterpret_cast<const fp16_t *>(data);
+        float f = _f16_to_f32(ptr[0]);
+        return static_cast<T>(f);
+    }
+    case DataType::BF16: {
+        const bf16_t *ptr = reinterpret_cast<const bf16_t *>(data);
+        float f = _bf16_to_f32(ptr[0]);
+        return static_cast<T>(f);
+    }
+    case DataType::BOOL: {
+        const bool *ptr = reinterpret_cast<const bool *>(data);
+        return static_cast<T>(ptr[0] ? 1 : 0);
+    }
+    default:
+        break;
+    }
+
+    const T *ptr = reinterpret_cast<const T *>(data);
+    return ptr[0];
+}
+
+/**
+ * @brief Extracts a scalar value from a single-element tensor.
+ * The tensor must have exactly one element and must be located on the CPU device.
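+ * Half-precision values (F16/BF16) are widened to float by item_impl() before the final cast.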
+ *
+ * @code{.cpp}
+ * float value = item<float>(tensor); // extract the single element as a float
+ * @endcode
+ */
+template <typename T>
+T item(const Tensor &tensor) {
+    if (tensor->numel() != 1) {
+        throw std::runtime_error("item() can only be called on a tensor with exactly one element, but got " + std::to_string(tensor->numel()) + " elements");
+    }
+
+    if (tensor->device().getType() != Device::Type::CPU) {
+        throw std::runtime_error("item() can only be called on a CPU tensor, but got device: " + tensor->device().toString());
+    }
+
+    const std::byte *data = tensor->data();
+    DataType dtype = tensor->dtype();
+    return item_impl<T>(data, dtype);
+}
+} // namespace
+
+namespace infinicore {
+namespace print_options {
+
+template <class S>
+class fmtflags_guard {
+public:
+    explicit fmtflags_guard(S &stream)
+        : m_stream(stream), m_flags(stream.flags()) {}
+    ~fmtflags_guard() { m_stream.flags(m_flags); }
+
+private:
+    S &m_stream;
+    std::ios_base::fmtflags m_flags;
+};
+
+struct PrintOptionsImpl {
+    int edge_items = 3;  // default edge items: print 3 items at each edge of a dimension.
+    int line_width = 80; // default line width: 80 chars per line.
+    int threshold = 1000; // default threshold: summarize tensors with more than 1000 elements.
+    int precision = -1;  // default precision: -1 means no precision limit.
+    int sci_mode = -1;   // default sci_mode: -1 means auto decision.
+};
+
+inline PrintOptionsImpl &print_options() {
+    static PrintOptionsImpl po;
+    return po;
+}
+
+void set_line_width(int line_width) {
+    print_options().line_width = line_width;
+}
+
+void set_threshold(int threshold) {
+    print_options().threshold = threshold;
+}
+
+void set_edge_items(int edge_items) {
+    print_options().edge_items = edge_items;
+}
+
+void set_precision(int precision) {
+    print_options().precision = precision;
+}
+
+void set_sci_mode(int sci_mode) {
+    print_options().sci_mode = sci_mode;
+}
+
+/**
+ * @brief Reads print options from the output stream and global settings.
+ */
+inline print_options::PrintOptionsImpl get_print_options(std::ostream &out) {
+    print_options::PrintOptionsImpl res;
+
+// Macro to read an option from the stream, falling back to the global default,
+// and resetting the stream-local value after a single use
+#define PROCESS_PRINT_OPTION(OPTION)                                        \
+    res.OPTION = static_cast<int>(out.iword(print_options::OPTION::id()));  \
+    if (res.OPTION > 0) {                                                   \
+        out.iword(print_options::OPTION::id()) = long(-1);                  \
+    } else {                                                                \
+        res.OPTION = print_options::print_options().OPTION;                 \
+    }
+
+    // Process all print options
+    PROCESS_PRINT_OPTION(edge_items);
+    PROCESS_PRINT_OPTION(line_width);
+    PROCESS_PRINT_OPTION(threshold);
+    PROCESS_PRINT_OPTION(precision);
+
+    res.sci_mode = print_options::print_options().sci_mode;
+
+#undef PROCESS_PRINT_OPTION
+    return res;
+}
+
+template <class T, class Enable = void>
+struct Printer;
+
+/**
+ * @brief Printer specialization for floating-point types (float, double, long double).
+ */
+template <class T>
+struct Printer<T, std::enable_if_t<std::is_floating_point<T>::value>> {
+    using value_type = T;
+    using cache_type = std::vector<value_type>;
+    using cache_iterator = typename cache_type::const_iterator;
+
+    explicit Printer(std::streamsize precision, int sci_mode = 0) : m_precision(precision), m_sci_mode(sci_mode) {}
+
+    void calculate() {
+        m_precision = m_required_precision < m_precision ? m_required_precision : m_precision;
+        m_it = m_cache.cbegin();
+
+        // Apply a user-forced sci mode; -1 keeps the automatic choice made in update()
+        if (1 == m_sci_mode) {
+            m_scientific = true;
+        } else if (0 == m_sci_mode) {
+            m_scientific = false;
+        }
+
+        if (m_scientific) {
+            // 3 = sign, leading digit and dot; 4 = "e+00"
+            m_width = m_precision + 7;
+            if (m_large_exponent) {
+                // e+000 (one additional exponent digit)
+                m_width += 1;
+            }
+        } else {
+            std::streamsize decimals = 1; // print a leading 0
+            if (std::floor(m_max) != 0) {
+                decimals += std::streamsize(std::log10(std::floor(m_max)));
+            }
+            // 2 => sign and dot
+            m_width = 2 + decimals + m_precision;
+        }
+        if (!m_required_precision) {
+            --m_width;
+        }
+    }
+
+    std::ostream &print_next(std::ostream &out) {
+        if (m_scientific) {
+            if (!m_large_exponent) {
+                out << std::scientific;
+                out.width(m_width);
+                out << (*m_it);
+            } else {
+                std::stringstream buf;
+                buf.width(m_width);
+                buf << std::scientific;
+                buf.precision(m_precision);
+                buf << (*m_it);
+                std::string res = buf.str();
+
+                if (res[res.size() - 4] == 'e') {
+                    res.erase(0, 1);
+                    res.insert(res.size() - 2, "0");
+                }
+                out << res;
+            }
+        } else {
+            std::stringstream buf;
+            buf.width(m_width);
+            buf << std::fixed;
+            buf.precision(m_precision);
+            buf << (*m_it);
+            if (!m_required_precision && !std::isinf(*m_it) && !std::isnan(*m_it)) {
+                buf << '.';
+            }
+            std::string res = buf.str();
+            auto sit = res.rbegin();
+            while (*sit == '0') {
+                *sit = ' ';
+                ++sit;
+            }
+            out << res;
+        }
+        ++m_it;
+        return out;
+    }
+
+    void update(const value_type &val) {
+        if (val != 0 && !std::isinf(val) && !std::isnan(val)) {
+            if (!m_scientific || !m_large_exponent) {
+                int exponent = 1 + int(std::log10(std::abs(val)));
+                if (exponent <= -5 || exponent > 7) {
+                    m_scientific = true;
+                    m_required_precision = m_precision;
+                    if (exponent <= -100 || exponent >= 100) {
+                        m_large_exponent = true;
+                    }
+                }
+            }
+
+            if (std::abs(val) > m_max) {
+                m_max = std::abs(val);
+            }
+            if (m_required_precision < m_precision) {
+                while (std::floor(val * std::pow(10, m_required_precision)) != val * std::pow(10, m_required_precision)) {
+                    m_required_precision++;
+                }
+            }
+        }
+        m_cache.push_back(val);
+    }
+
+    std::streamsize width() const { return m_width; }
+
+private:
+    bool m_large_exponent = false;
+    bool m_scientific = false;
+
+    std::streamsize m_width = 9;
+    std::streamsize m_precision;
+    std::streamsize m_required_precision = 0;
+    value_type m_max = 0;
+    int m_sci_mode = -1;
+    cache_type m_cache;
+    cache_iterator m_it;
+};
+
+/**
+ * @brief Printer specialization for integer types (signed and unsigned integers).
+ */
+template <class T>
+struct Printer<
+    T, std::enable_if_t<std::is_integral<T>::value && !std::is_same<T, bool>::value>> {
+    using value_type = T;
+    using cache_type = std::vector<value_type>;
+    using cache_iterator = typename cache_type::const_iterator;
+
+    explicit Printer(std::streamsize, int /*sci_mode*/ = 0) {}
+
+    void calculate() {
+        m_it = m_cache.cbegin();
+        m_width = 1 + std::streamsize((m_max > 0) ? std::log10(m_max) : 0) + m_sign;
+    }
+
+    std::ostream &print_next(std::ostream &out) {
+        // unary + promotes char-like types so they print as numbers
+        // TODO: should chars be printed as numbers?
+        out.width(m_width);
+        out << +(*m_it);
+        ++m_it;
+        return out;
+    }
+
+    void update(const value_type &val) {
+        // For unsigned types abs is not needed (always non-negative);
+        // for signed types we take the absolute value
+        value_type abs_val;
+        if constexpr (std::is_signed<T>::value) {
+            abs_val = (val < 0) ? -val : val;
+        } else {
+            abs_val = val;
+        }
+
+        if (abs_val > m_max) {
+            m_max = abs_val;
+        }
+
+        if (std::is_signed<T>::value && val < 0) {
+            m_sign = true;
+        }
+        m_cache.push_back(val);
+    }
+
+    std::streamsize width() { return m_width; }
+
+private:
+    std::streamsize m_width;
+    bool m_sign = false;
+    value_type m_max = 0;
+
+    cache_type m_cache;
+    cache_iterator m_it;
+};
+
+/**
+ * @brief Printer specialization for bool type.
+ */
+template <class T>
+struct Printer<
+    T, std::enable_if_t<std::is_same<T, bool>::value>> {
+    using value_type = bool;
+    using cache_type = std::vector<value_type>;
+    using cache_iterator = typename cache_type::const_iterator;
+
+    explicit Printer(std::streamsize, int /*sci_mode*/ = 0) {}
+
+    void calculate() {
+        m_it = m_cache.cbegin();
+    }
+
+    std::ostream &print_next(std::ostream &out) {
+        if (*m_it) {
+            out << " true";
+        } else {
+            out << "false";
+        }
+        // TODO: the following std::setw(5) isn't working correctly on OSX.
+        // out << std::boolalpha << std::setw(m_width) << (*m_it);
+        ++m_it;
+        return out;
+    }
+
+    void update(const value_type &val) { m_cache.push_back(val); }
+
+    std::streamsize width() { return m_width; }
+
+private:
+    std::streamsize m_width = 5;
+    cache_type m_cache;
+    cache_iterator m_it;
+};
+
+} // namespace print_options
+} // namespace infinicore
+
+namespace infinicore {
+namespace print_options {
+
+/**
+ * @brief Recursively traverses tensor dimensions to collect values for printing.
+ */
+template <class T>
+void recurser_run(Printer<T> &printer,
+                  const Tensor &tensor,
+                  std::vector<Size> indexes,
+                  std::size_t lim = 0) {
+
+    using size_type = Size;
+    const auto view = at(tensor, indexes);
+    if (view->ndim() == 0) {
+        T value = item<T>(view);
+        printer.update(value);
+    } else {
+        size_type i = 0;
+        for (; i != static_cast<size_type>(view->shape()[0] - 1); ++i) {
+            if (lim && size_type(view->shape()[0]) > (lim * 2) && i == lim) {
+                i = static_cast<size_type>(view->shape()[0]) - lim;
+            }
+            indexes.push_back(static_cast<Size>(i));
+            recurser_run(printer, tensor, indexes, lim);
+            indexes.pop_back();
+        }
+        indexes.push_back(static_cast<Size>(i));
+        recurser_run(printer, tensor, indexes, lim);
+        indexes.pop_back();
+    }
+}
+
+/**
+ * @brief Recursively prints tensor elements with proper formatting.
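+ *
+ * @param blanks        number of leading spaces used to indent wrapped/nested rows
+ * @param element_width fixed width of one printed element, as computed by the Printer
+ * @param edge_items    when non-zero, print only this many items at each edge of a dimension
+ * @param line_width    maximum characters per line before wrapping 1-D output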
+ */
+template <class T>
+std::ostream &xoutput(std::ostream &out,
+                      const Tensor &tensor,
+                      std::vector<Size> &indexes,
+                      Printer<T> &printer,
+                      std::size_t blanks,
+                      std::streamsize element_width,
+                      std::size_t edge_items,
+                      std::size_t line_width) {
+
+    using size_type = Size;
+    const auto view = at(tensor, indexes);
+    if (view->ndim() == 0) {
+        printer.print_next(out);
+    } else {
+        std::string indents(blanks, ' ');
+
+        size_type i = 0;
+        size_type elems_on_line = 0;
+        const size_type ewp2 = static_cast<size_type>(element_width) + size_type(2);
+        const size_type line_lim = static_cast<size_type>(std::floor(line_width / ewp2));
+
+        out << '[';
+        for (; i != size_type(view->shape()[0] - 1); ++i) {
+
+            if (edge_items && size_type(view->shape()[0]) > (edge_items * 2) && i == edge_items) {
+                if (view->ndim() == 1 && line_lim != 0 && elems_on_line >= line_lim) {
+                    out << " ...,";
+                } else if (view->ndim() > 1) {
+                    elems_on_line = 0;
+                    out << "...," << std::endl
+                        << indents;
+                } else {
+                    out << "..., ";
+                }
+                i = size_type(view->shape()[0]) - edge_items;
+                if (edge_items <= 1) {
+                    break;
+                }
+            }
+            if (view->ndim() == 1 && line_lim != 0 && elems_on_line >= line_lim) {
+                out << std::endl
+                    << indents;
+                elems_on_line = 0;
+            }
+
+            indexes.push_back(static_cast<Size>(i));
+            xoutput(out, tensor, indexes, printer, blanks + 1, element_width, edge_items,
+                    line_width)
+                << ',';
+            indexes.pop_back();
+            elems_on_line++;
+
+            if ((view->ndim() == 1) && !(line_lim != 0 && elems_on_line >= line_lim)) {
+                out << ' ';
+            } else if (view->ndim() > 1) {
+                out << std::endl
+                    << indents;
+            }
+        }
+        if (view->ndim() == 1 && line_lim != 0 && elems_on_line >= line_lim) {
+            out << std::endl
+                << indents;
+        }
+
+        indexes.push_back(static_cast<Size>(i));
+        xoutput(out, tensor, indexes, printer, blanks + 1, element_width, edge_items,
+                line_width)
+            << ']';
+        indexes.pop_back();
+    }
+    return out;
+}
+
+template <class T>
+std::ostream &pretty_print(const Tensor &tensor,
+                           std::ostream &out = std::cout) {
+    fmtflags_guard<std::ostream> guard(out);
+
+    std::size_t edge_items = 0;
+    Size sz = tensor->numel();
+    auto po = get_print_options(out);
+
+    if (sz > static_cast<Size>(po.threshold)) {
+        edge_items = static_cast<std::size_t>(po.edge_items);
+    }
+    if (sz == 0) {
+        out << "[]";
+        return out;
+    }
+
+    auto temp_precision = out.precision();
+    auto precision = temp_precision;
+
+    if (po.precision != -1) {
+        out.precision(static_cast<std::streamsize>(po.precision));
+        precision = static_cast<std::streamsize>(po.precision);
+    }
+
+    Printer<T> printer(precision, po.sci_mode);
+    std::vector<Size> indexes = {};
+
+    // pass edge_items so the cached values match the summarized traversal in xoutput
+    recurser_run(printer, tensor, indexes, edge_items);
+
+    printer.calculate();
+    indexes.clear();
+
+    out << "tensor(";
+    xoutput(out,
+            tensor,
+            indexes,
+            printer,
+            1 + 7, // indent past "tensor("
+            printer.width(),
+            edge_items,
+            static_cast<std::size_t>(po.line_width));
+
+    out << ", dtype=INFINI."
<< toString(tensor->dtype()) << ")\n"; + out.precision(temp_precision); // restore precision + return out; +} + +} // namespace print_options +} // namespace infinicore + +namespace infinicore { +std::ostream &operator<<(std::ostream &out, const Tensor &tensor) { + if (tensor->device() != Device::Type::CPU) { + throw std::runtime_error("cant not print tensor on non-CPU device !!!"); + } + + switch (tensor->dtype()) { + case DataType::BYTE: // 1 + { + throw std::runtime_error("cant not print INFINI.BYTE dtype tensor !!!"); + } + case DataType::BOOL: // 2 + { + return infinicore::print_options::pretty_print(tensor, out); + } + case DataType::I8: // 3 + { + return infinicore::print_options::pretty_print(tensor, out); + } + case DataType::I16: // 4 + { + return infinicore::print_options::pretty_print(tensor, out); + } + case DataType::I32: // 5 + { + return infinicore::print_options::pretty_print(tensor, out); + } + case DataType::I64: // 6 + { + return infinicore::print_options::pretty_print(tensor, out); + } + case DataType::U8: // 7 + { + return infinicore::print_options::pretty_print(tensor, out); + } + case DataType::U16: // 8 + { + return infinicore::print_options::pretty_print(tensor, out); + } + case DataType::U32: // 9 + { + return infinicore::print_options::pretty_print(tensor, out); + } + case DataType::U64: // 10 + { + return infinicore::print_options::pretty_print(tensor, out); + } + case DataType::F8: // 11 + { + throw std::runtime_error("cant not print INFINI.F8 dtype tensor !!!"); + } + case DataType::F16: // 12 + { + // Convert F16 to F32 for printing + return infinicore::print_options::pretty_print(tensor, out); + } + case DataType::F32: // 13 + { + return infinicore::print_options::pretty_print(tensor, out); + } + case DataType::F64: // 14 + { + return infinicore::print_options::pretty_print(tensor, out); + } + case DataType::C16: // 15 + case DataType::C32: // 16 + case DataType::C64: // 17 + case DataType::C128: // 18 + { + throw std::runtime_error("cant not print Complex dtype tensor !!!"); + } + case DataType::BF16: // 19 + { + // Convert BF16 to F32 for printing + return infinicore::print_options::pretty_print(tensor, out); + } + default: + throw std::runtime_error("cant not print unknown dtype tensor : " + toString(tensor->dtype())); + } + + return out; +} +} // namespace infinicore diff --git a/test/infinicore/test.py b/test/infinicore/test.py index 36aeffe4e..a45cec3c2 100644 --- a/test/infinicore/test.py +++ b/test/infinicore/test.py @@ -1,5 +1,7 @@ import torch from infinicore.lib import _infinicore +from infinicore.utils import to_torch_dtype +import numpy as np import infinicore @@ -265,6 +267,145 @@ def func6_initialize_device_relationship(): z_infini.debug() +def func7_print_different_data_types(): + """Test printing for different data types.""" + + # Test cases: (dtype_name, dtype_object, test_data) + test_cases = [ + ("BOOL", infinicore.bool, [[True, False], [False, True]]), + ("I8", infinicore.int8, [[-128, -64], [32, 127]]), + ("I16", infinicore.int16, [[-32768, -16384], [8192, 32767]]), + ( + "I32", + infinicore.int32, + [[-2147483648, -1073741824], [1073741824, 2147483647]], + ), + ( + "I64", + infinicore.int64, + [ + [-1000000000000000000, -500000000000000000], + [500000000000000000, 1000000000000000000], + ], + ), + ("U8", infinicore.uint8, [[0, 64], [192, 255]]), + ("BF16", infinicore.bfloat16, [[1.234, 2.345], [4.567, 5.678]]), + ("F16", infinicore.float16, [[1.234, 2.345], [4.567, 5.678]]), + ("F32", infinicore.float32, [[1.234, 2.34], [4.569, 
5.9]]), + ("F64", infinicore.float64, [[1.23456789111, 2.3456789], [4.56789, 5.6789]]), + ] + + for dtype_name, dtype_obj, test_data in test_cases: + print(f"\n{'=' * 70}") + print(f"Testing DataType::{dtype_name}") + print(f"{'=' * 70}") + + # Create infinicore tensor + t_infini = infinicore.from_list( + test_data, dtype=dtype_obj, device=infinicore.device("cpu") + ) + print("\n[Infinicore] Default print options:") + print(t_infini) + + # Compare with PyTorch if supported + torch_dtype = to_torch_dtype(dtype_obj) + if torch_dtype is not None: + t_torch = torch.tensor(test_data, dtype=torch_dtype) + print("\n[PyTorch] Default print options:") + print(t_torch) + else: + print(f"\n[PyTorch] DataType {dtype_name} not supported by PyTorch") + + +def func8_print_options(): + """Test global print options: precision, threshold, edgeitems, linewidth, sci_mode""" + print(f"\n{'=' * 70}") + print("Testing global print options configuration") + print(f"{'=' * 70}") + + # Create test tensors of different sizes + test_tensors = { + "Small (3x3)": infinicore.from_list( + [[1.211, 2.389, 3.89], [4.569, 5.689, 6.789], [7.89, 8.9, 9.0]], + dtype=infinicore.float64, + ), + "Medium (8x8)": infinicore.from_list( + np.random.randn(8, 8).tolist(), dtype=infinicore.float32 + ), + "Large (15x15)": infinicore.from_list( + np.random.randn(15, 15).tolist(), dtype=infinicore.float32 + ), + } + + # Test cases: (name, options_dict) + test_cases = [ + ("Precision: 2", {"precision": 2}), + ("Precision: 6", {"precision": 6}), + ("Precision: -1 (auto)", {"precision": -1}), + ("Threshold: 50, Edgeitems: 2", {"threshold": 50, "edgeitems": 2}), + ("Threshold: 200, Edgeitems: 4", {"threshold": 200, "edgeitems": 1}), + ("Linewidth: 40", {"linewidth": 40}), + ("Sci_mode: True (scientific)", {"sci_mode": True}), + ("Sci_mode: False (normal)", {"sci_mode": False}), + ("Sci_mode: None (auto)", {"sci_mode": None}), + ("Combined: p=1, t=50, e=2", {"precision": 1, "threshold": 50, "edgeitems": 2}), + ( + "Combined: p=6, t=100, e=1, sci=True", + {"precision": 6, "threshold": 100, "edgeitems": 1, "sci_mode": True}, + ), + ] + + for case_name, options in test_cases: + print(f"\n{'=' * 70}") + print(f"Test Case: {case_name}") + print(f" Options: {options}") + print(f"{'=' * 70}") + + # Set print options + infinicore.set_printoptions(**options) + + # Print all test tensors + for tensor_name, tensor in test_tensors.items(): + print(f"\n[{tensor_name}]:") + print(tensor) + + # Reset to defaults + infinicore.set_printoptions( + precision=-1, threshold=1000, edgeitems=3, linewidth=80, sci_mode=None + ) + + +def func9_print_temporary_options(): + """Test that temporary print options work correctly and don't affect global settings.""" + print(f"\n{'=' * 70}") + print("Testing temporary print options (context manager)") + print(f"{'=' * 70}") + + # Set initial global print options + infinicore.set_printoptions( + precision=4, threshold=1000, edgeitems=3, linewidth=80, sci_mode=None + ) + + # Create test tensor + test_data = [[1.211, 2.389, 3.89], [4.569, 5.689, 6.789], [7.89, 8.9, 9.0]] + t_small = infinicore.from_list(test_data, dtype=infinicore.float64) + + # Verify initial settings + print("Tensor output:") + print(t_small) + + # Enter context with temporary settings + with infinicore.printoptions( + precision=2, threshold=50, edgeitems=2, linewidth=40, sci_mode=True + ): + print("Tensor output (with temporary settings):") + print(t_small) + + # Verify global settings are restored + print("Tensor output (should match before context):") + 
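+    # printoptions() restores the saved global options in its finally block,
+    # so this print must match the first one above.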
+    print(t_small)
+
+
 if __name__ == "__main__":
     test()
     test2()
@@ -272,3 +413,6 @@ def func6_initialize_device_relationship():
     test4_to()
     test5_bf16()
     func6_initialize_device_relationship()
+    func7_print_different_data_types()
+    func8_print_options()
+    func9_print_temporary_options()
diff --git a/xmake/test.lua b/xmake/test.lua
index 002083e1d..e9dc86cd4 100644
--- a/xmake/test.lua
+++ b/xmake/test.lua
@@ -66,7 +66,7 @@ target_end()
 
 target("infinicore-test")
     set_kind("binary")
-    add_deps("infiniop", "infinirt", "infiniccl")
+    add_deps("infiniop", "infinirt", "infiniccl", "infinicore_cpp_api")
     set_default(false)
     set_languages("cxx17")