Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions include/infinicore/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <algorithm>
#include <cassert>
#include <functional>
#include <iostream>
#include <memory>
#include <vector>

Expand Down Expand Up @@ -90,6 +91,7 @@ class Tensor {
Tensor(std::shared_ptr<TensorImpl> impl) : impl_(std::move(impl)) {}
std::shared_ptr<TensorImpl> impl_;
friend class TensorImpl;
friend std::ostream &operator<<(std::ostream &os, const Tensor &tensor);
};

class TensorImpl : public std::enable_shared_from_this<TensorImpl> {
Expand Down Expand Up @@ -302,3 +304,87 @@ class TensorImpl : public std::enable_shared_from_this<TensorImpl> {
};

} // namespace infinicore

namespace infinicore {

namespace print_options {
/**
* @brief Sets the line width. After \a line_width chars, a new line is added.
* @param line_width The line width
*/
void set_line_width(int line_width);

/**
* @brief Sets the threshold after which summarization is triggered (default: 1000).
* @param threshold The number of elements in the tensor that triggers summarization in the output
*/
void set_threshold(int threshold);

/**
* @brief Sets the number of edge items.
* If the summarization is triggered, this value defines how many items of each dimension are printed.
* @param edge_items The number of edge items
*/
void set_edge_items(int edge_items);

/**
* @brief Sets the precision for printing floating point values.
* @param precision The number of digits for floating point output
*/

void set_precision(int precision);

/**
* @brief Sets the sci mode of the floating point values when printing an Tensor.
* @param sci_mode The sci mode: -1 for auto decision, 0 to disable, 1 to enable
*/

void set_sci_mode(int sci_mode); // -1: auto, 0: disable, 1: enable

#define DEFINE_LOCAL_PRINT_OPTION(NAME) \
class NAME { \
public: \
NAME(int value) : m_value(value) { id(); } \
static int id() { \
static int id = std::ios_base::xalloc(); \
return id; \
} \
int value() const { return m_value; } \
\
private: \
int m_value; \
}; \
\
inline std::ostream &operator<<(std::ostream &out, const NAME &n) { \
out.iword(NAME::id()) = n.value(); \
return out; \
}

/**
* @class line_width
* io manipulator used to set the width of the lines when printing an Tensor.
*
* @code{.cpp}
* using po = infinicore::print_options;
* std::cout << po::line_width(100) << tensor << std::endl;
* @endcode
*/
DEFINE_LOCAL_PRINT_OPTION(line_width)

/**
* io manipulator used to set the threshold after which summarization is triggered.
*/
DEFINE_LOCAL_PRINT_OPTION(threshold)

/**
* io manipulator used to set the number of egde items if the summarization is triggered.
*/
DEFINE_LOCAL_PRINT_OPTION(edge_items)

/**
* io manipulator used to set the precision of the floating point values when printing an Tensor.
*/
DEFINE_LOCAL_PRINT_OPTION(precision)

} // namespace print_options
} // namespace infinicore
3 changes: 3 additions & 0 deletions python/infinicore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import infinicore.context as context
import infinicore.nn as nn
from infinicore._tensor_str import printoptions, set_printoptions
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch中把print的一些 函数调用放到了_tensor_str.py文件中,于是也新建了一个 _tensor_str


# Import context functions
from infinicore.context import (
Expand Down Expand Up @@ -134,6 +135,8 @@
"strided_empty",
"strided_from_blob",
"zeros",
"set_printoptions",
"printoptions",
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch中存在一下两个函数
"set_printoptions":全局配置的函数
"printoptions": 临时配置的函数

]

use_ntops = False
Expand Down
114 changes: 114 additions & 0 deletions python/infinicore/_tensor_str.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import contextlib
import dataclasses
from typing import Any, Optional

from infinicore.lib import _infinicore


@dataclasses.dataclass
class __PrinterOptions:
precision: int = 4
threshold: float = 1000
edgeitems: int = 3
linewidth: int = 80
sci_mode: Optional[bool] = None


PRINT_OPTS = __PrinterOptions()


def set_printoptions(
precision=None,
threshold=None,
edgeitems=None,
linewidth=None,
profile=None,
sci_mode=None,
):
r"""Set options for printing.
Args:
precision: Number of digits of precision for floating point output (default = 4).
threshold: Total number of array elements which trigger summarization rather than full `repr` (default = 1000).
edgeitems: Number of array items in summary at beginning and end of each dimension (default = 3).
linewidth: The number of characters per line (default = 80).
profile: Sane defaults for pretty printing. Can override with any of the above options. (any one of `default`, `short`, `full`)
sci_mode: Enable (True) or disable (False) scientific notation.
If None (default) is specified, the value is automatically chosen by the framework.

Example::
>>> # Limit the precision of elements
>>> torch.set_printoptions(precision=2)
>>> torch.tensor([1.12345])
tensor([1.12])
"""
if profile is not None:
if profile == "default":
PRINT_OPTS.precision = 4
PRINT_OPTS.threshold = 1000
PRINT_OPTS.edgeitems = 3
PRINT_OPTS.linewidth = 80
elif profile == "short":
PRINT_OPTS.precision = 2
PRINT_OPTS.threshold = 1000
PRINT_OPTS.edgeitems = 2
PRINT_OPTS.linewidth = 80
elif profile == "full":
PRINT_OPTS.precision = 4
PRINT_OPTS.threshold = 2147483647 # CPP_INT32_MAX
PRINT_OPTS.edgeitems = 3
PRINT_OPTS.linewidth = 80
else:
raise ValueError(
f"Invalid profile: {profile}. the profile must be one of 'default', 'short', 'full'"
)

if precision is not None:
PRINT_OPTS.precision = precision
if threshold is not None:
PRINT_OPTS.threshold = threshold
if edgeitems is not None:
PRINT_OPTS.edgeitems = edgeitems
if linewidth is not None:
PRINT_OPTS.linewidth = linewidth
PRINT_OPTS.sci_mode = sci_mode

_infinicore.set_printoptions(
PRINT_OPTS.precision,
PRINT_OPTS.threshold,
PRINT_OPTS.edgeitems,
PRINT_OPTS.linewidth,
PRINT_OPTS.sci_mode,
)


def get_printoptions() -> dict[str, Any]:
r"""Gets the current options for printing, as a dictionary that
can be passed as ``**kwargs`` to set_printoptions().
"""
return dataclasses.asdict(PRINT_OPTS)


@contextlib.contextmanager
def printoptions(
precision=None, threshold=None, edgeitems=None, linewidth=None, sci_mode=None
):
r"""Context manager that temporarily changes the print options."""
old_kwargs = get_printoptions()

set_printoptions(
precision=precision,
threshold=threshold,
edgeitems=edgeitems,
linewidth=linewidth,
sci_mode=sci_mode,
)
try:
yield
finally:
set_printoptions(**old_kwargs)


def _str(self):
cpp_tensor_str = self._underlying.__str__()
py_dtype_str = "dtype=" + self.dtype.__repr__()
return cpp_tensor_str.split("dtype=INFINI.")[0] + py_dtype_str + ")\n"
4 changes: 4 additions & 0 deletions python/infinicore/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import infinicore.dtype
from infinicore.lib import _infinicore

from ._tensor_str import _str
from .utils import (
infinicore_to_numpy_dtype,
numpy_to_infinicore_dtype,
Expand Down Expand Up @@ -130,6 +131,9 @@ def __mul__(self, other):
def narrow(self, dim, start, length):
return infinicore.narrow(self, dim, start, length)

def __repr__(self):
return _str(self)


def empty(size, *, dtype=None, device=None, pin_memory=False):
return Tensor(
Expand Down
6 changes: 6 additions & 0 deletions python/infinicore/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ def to_torch_dtype(infini_dtype):
return torch.float16
elif infini_dtype == infinicore.float32:
return torch.float32
elif infini_dtype == infinicore.float64:
return torch.float64
elif infini_dtype == infinicore.bfloat16:
return torch.bfloat16
elif infini_dtype == infinicore.int8:
Expand All @@ -23,6 +25,8 @@ def to_torch_dtype(infini_dtype):
return torch.int64
elif infini_dtype == infinicore.uint8:
return torch.uint8
elif infini_dtype == infinicore.bool:
return torch.bool
else:
raise ValueError(f"Unsupported infinicore dtype: {infini_dtype}")

Expand Down Expand Up @@ -93,5 +97,7 @@ def infinicore_to_numpy_dtype(infini_dtype):
return np.int64
elif infini_dtype == infinicore.uint8:
return np.uint8
elif infini_dtype == infinicore.bool:
return np.bool_
else:
raise ValueError(f"Unsupported infinicore dtype: {infini_dtype}")
16 changes: 8 additions & 8 deletions src/infinicore-test/test_nn_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -898,8 +898,8 @@ TestResult NNModuleTest::testModuleLinear() {

// Test forward with residual connection
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

c++的linear有个版移除了一个函数,导致c++测试编译不过。于是,在src/infinicore-test/test_nn_module.cc中注释或删除了传递的参数

spdlog::info("Testing Linear forward with residual connection");
auto residual = infinicore::Tensor::ones({2, 4}, infinicore::DataType::F32, infinicore::Device());
auto output_with_residual = m1.forward(input1, residual);
// auto residual = infinicore::Tensor::ones({2, 4}, infinicore::DataType::F32, infinicore::Device());
auto output_with_residual = m1.forward(input1);
if (output_with_residual->shape() != std::vector<size_t>({2, 4})) {
spdlog::error("Linear output with residual shape mismatch. Expected {{2, 4}}, got different shape");
return false;
Expand All @@ -911,10 +911,10 @@ TestResult NNModuleTest::testModuleLinear() {

// Create test data with known values for verification
auto test_input = infinicore::Tensor::ones({2, 8}, infinicore::DataType::F32, infinicore::Device());
auto test_residual = infinicore::Tensor::ones({2, 4}, infinicore::DataType::F32, infinicore::Device());
// auto test_residual = infinicore::Tensor::ones({2, 4}, infinicore::DataType::F32, infinicore::Device());

// Get InfiniCore result
auto infinicore_output = m1.forward(test_input, test_residual);
auto infinicore_output = m1.forward(test_input);

// Compute naive result: output = input @ weight.T + bias + residual
auto naive_output = infinicore::Tensor::empty({2, 4}, infinicore::DataType::F32, infinicore::Device());
Expand All @@ -935,7 +935,7 @@ TestResult NNModuleTest::testModuleLinear() {
infinicore::op::add_(naive_output, matmul_result, bias_view);

// Add residual
infinicore::op::add_(naive_output, naive_output, test_residual);
// infinicore::op::add_(naive_output, naive_output, test_residual);

// Compare results with actual value checking
if (infinicore_output->shape() != naive_output->shape()) {
Expand All @@ -956,10 +956,10 @@ TestResult NNModuleTest::testModuleLinear() {
// Test computation correctness without bias (using m2)
spdlog::info("Testing computation correctness without bias");
auto test_input_no_bias = infinicore::Tensor::ones({1, 16}, infinicore::DataType::F32, infinicore::Device());
auto test_residual_no_bias = infinicore::Tensor::ones({1, 3}, infinicore::DataType::F32, infinicore::Device());
// auto test_residual_no_bias = infinicore::Tensor::ones({1, 3}, infinicore::DataType::F32, infinicore::Device());

// Get InfiniCore result (no bias)
auto infinicore_output_no_bias = m2.forward(test_input_no_bias, test_residual_no_bias);
auto infinicore_output_no_bias = m2.forward(test_input_no_bias);

// Compute naive result without bias: output = input @ weight.T + residual
auto naive_output_no_bias = infinicore::Tensor::empty({1, 3}, infinicore::DataType::F32, infinicore::Device());
Expand All @@ -970,7 +970,7 @@ TestResult NNModuleTest::testModuleLinear() {
auto matmul_result_no_bias = infinicore::op::matmul(test_input_no_bias, weight_t_no_bias); // [1, 3]

// Add residual
infinicore::op::add_(naive_output_no_bias, matmul_result_no_bias, test_residual_no_bias);
// infinicore::op::add_(naive_output_no_bias, matmul_result_no_bias, test_residual_no_bias);

// Compare results with actual value checking
if (infinicore_output_no_bias->shape() != naive_output_no_bias->shape()) {
Expand Down
Loading