diff --git a/ast/node.go b/ast/node.go index f17fd7e52..198efa59b 100644 --- a/ast/node.go +++ b/ast/node.go @@ -104,6 +104,12 @@ type StringNode struct { Value string // Value of the string. } +// BytesNode represents a byte slice. +type BytesNode struct { + base + Value []byte // Value of the byte slice. +} + // ConstantNode represents a constant. // Constants are predefined values like nil, true, false, array, map, etc. // The parser.Parse will never generate ConstantNode, it is only generated diff --git a/ast/print.go b/ast/print.go index 527a5b99b..1c197445e 100644 --- a/ast/print.go +++ b/ast/print.go @@ -33,6 +33,10 @@ func (n *StringNode) String() string { return fmt.Sprintf("%q", n.Value) } +func (n *BytesNode) String() string { + return fmt.Sprintf("b%q", n.Value) +} + func (n *ConstantNode) String() string { if n.Value == nil { return "nil" diff --git a/ast/visitor.go b/ast/visitor.go index 72cd6366b..ef23758e1 100644 --- a/ast/visitor.go +++ b/ast/visitor.go @@ -17,6 +17,7 @@ func Walk(node *Node, v Visitor) { case *FloatNode: case *BoolNode: case *StringNode: + case *BytesNode: case *ConstantNode: case *UnaryNode: Walk(&n.Node, v) diff --git a/checker/checker.go b/checker/checker.go index 2a9e234a1..b78afc2b4 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -24,6 +24,7 @@ var ( mapType = reflect.TypeOf(map[string]any{}) timeType = reflect.TypeOf(time.Time{}) durationType = reflect.TypeOf(time.Duration(0)) + byteSliceType = reflect.TypeOf([]byte(nil)) anyTypeSlice = []reflect.Type{anyType} ) @@ -194,6 +195,8 @@ func (v *Checker) visit(node ast.Node) Nature { nt = v.config.NtCache.FromType(boolType) case *ast.StringNode: nt = v.config.NtCache.FromType(stringType) + case *ast.BytesNode: + nt = v.config.NtCache.FromType(byteSliceType) case *ast.ConstantNode: nt = v.config.NtCache.FromType(reflect.TypeOf(n.Value)) case *ast.UnaryNode: diff --git a/compiler/compiler.go b/compiler/compiler.go index 5de89f54c..951385cdb 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -254,6 +254,8 @@ func (c *compiler) compile(node ast.Node) { c.BoolNode(n) case *ast.StringNode: c.StringNode(n) + case *ast.BytesNode: + c.BytesNode(n) case *ast.ConstantNode: c.ConstantNode(n) case *ast.UnaryNode: @@ -410,6 +412,10 @@ func (c *compiler) StringNode(node *ast.StringNode) { c.emitPush(node.Value) } +func (c *compiler) BytesNode(node *ast.BytesNode) { + c.emitPush(node.Value) +} + func (c *compiler) ConstantNode(node *ast.ConstantNode) { if node.Value == nil { c.emit(OpNil) diff --git a/docs/language-definition.md b/docs/language-definition.md index 5e530a57d..69efbdfa9 100644 --- a/docs/language-definition.md +++ b/docs/language-definition.md @@ -53,6 +53,12 @@ nil + + Bytes + + b"hello", b'\xff\x00' + + ### Strings @@ -73,6 +79,38 @@ World` Backticks strings are raw strings, they do not support escape sequences. +### Bytes + +Bytes literals are represented by string literals preceded by a `b` or `B` character. +The bytes literal returns a `[]byte` value. + +```expr +b"abc" // []byte{97, 98, 99} +``` + +Non-ASCII characters are UTF-8 encoded: + +```expr +b"ÿ" // []byte{195, 191} - UTF-8 encoding of ÿ +``` + +Bytes literals support escape sequences for specifying arbitrary byte values: + +- `\xNN` - hexadecimal escape (2 hex digits, value 0-255) +- `\NNN` - octal escape (3 octal digits, value 0-377) +- `\n`, `\t`, `\r`, etc. - standard escape sequences + +```expr +b"\xff" // []byte{255} +b"\x00\x01" // []byte{0, 1} +b"\101" // []byte{65} - octal for 'A' +``` + +:::note +Unlike string literals, bytes literals do not support `\u` or `\U` Unicode escapes. +Use `\x` escapes for arbitrary byte values. +::: + ## Operators diff --git a/expr_test.go b/expr_test.go index ba1f001ec..28b2c54be 100644 --- a/expr_test.go +++ b/expr_test.go @@ -70,6 +70,19 @@ func ExampleCompile() { // Output: true } +func ExampleEval_bytes_literal() { + // Bytes literal returns []byte. + output, err := expr.Eval(`b"abc"`, nil) + if err != nil { + fmt.Printf("%v", err) + return + } + + fmt.Printf("%v", output) + + // Output: [97 98 99] +} + func TestDisableIfOperator_AllowsIfFunction(t *testing.T) { env := map[string]any{ "if": func(x int) int { return x + 1 }, @@ -2929,3 +2942,60 @@ func TestDisableShortCircuit(t *testing.T) { assert.Equal(t, 3, count) assert.True(t, got.(bool)) } + +func TestBytesLiteral(t *testing.T) { + tests := []struct { + code string + want []byte + }{ + {`b"hello"`, []byte("hello")}, + {`b'world'`, []byte("world")}, + {`b""`, []byte{}}, + {`b'\x00\xff'`, []byte{0, 255}}, + {`b"\x41\x42\x43"`, []byte("ABC")}, + {`b'\101\102\103'`, []byte("ABC")}, + {`b'\n\t\r'`, []byte{'\n', '\t', '\r'}}, + {`b'hello\x00world'`, []byte("hello\x00world")}, + {`b"ÿ"`, []byte{0xc3, 0xbf}}, // UTF-8 encoding of ÿ + } + + for _, tt := range tests { + t.Run(tt.code, func(t *testing.T) { + program, err := expr.Compile(tt.code) + require.NoError(t, err) + + output, err := expr.Run(program, nil) + require.NoError(t, err) + assert.Equal(t, tt.want, output) + }) + } +} + +func TestBytesLiteral_type(t *testing.T) { + env := map[string]any{ + "data": []byte("test"), + } + + // Verify bytes literal has []byte type and can be compared with []byte + program, err := expr.Compile(`data == b"test"`, expr.Env(env)) + require.NoError(t, err) + + output, err := expr.Run(program, env) + require.NoError(t, err) + assert.Equal(t, true, output) +} + +func TestBytesLiteral_errors(t *testing.T) { + // \u and \U escapes should not be allowed in bytes literals + errorCases := []string{ + `b'\u0041'`, + `b"\U00000041"`, + } + + for _, code := range errorCases { + t.Run(code, func(t *testing.T) { + _, err := expr.Compile(code) + require.Error(t, err) + }) + } +} diff --git a/parser/lexer/lexer_test.go b/parser/lexer/lexer_test.go index f46870e9b..6212d066f 100644 --- a/parser/lexer/lexer_test.go +++ b/parser/lexer/lexer_test.go @@ -299,6 +299,52 @@ func TestLex(t *testing.T) { {Kind: EOF}, }, }, + { + `b"hello" b'world'`, + []Token{ + {Kind: Bytes, Value: "hello"}, + {Kind: Bytes, Value: "world"}, + {Kind: EOF}, + }, + }, + { + `b"\x00\xff" b'\x41\x42\x43'`, + []Token{ + {Kind: Bytes, Value: "\x00\xff"}, + {Kind: Bytes, Value: "ABC"}, + {Kind: EOF}, + }, + }, + { + `b"\101\102\103" b'\n\t\r'`, + []Token{ + {Kind: Bytes, Value: "ABC"}, + {Kind: Bytes, Value: "\n\t\r"}, + {Kind: EOF}, + }, + }, + { + `b""`, + []Token{ + {Kind: Bytes, Value: ""}, + {Kind: EOF}, + }, + }, + { + `B"hello" B'world'`, + []Token{ + {Kind: Bytes, Value: "hello"}, + {Kind: Bytes, Value: "world"}, + {Kind: EOF}, + }, + }, + { + `b"ÿ"`, + []Token{ + {Kind: Bytes, Value: "\xc3\xbf"}, + {Kind: EOF}, + }, + }, } for _, test := range tests { @@ -380,6 +426,16 @@ früh ♥︎ unrecognized character: U+2665 '♥' (1:6) | früh ♥︎ | .....^ + +b"\u0041" +unable to unescape string (1:9) + | b"\u0041" + | ........^ + +b'\U00000041' +unable to unescape string (1:13) + | b'\U00000041' + | ............^ ` func TestLex_error(t *testing.T) { diff --git a/parser/lexer/state.go b/parser/lexer/state.go index 91857eade..d606258b2 100644 --- a/parser/lexer/state.go +++ b/parser/lexer/state.go @@ -25,6 +25,14 @@ func root(l *Lexer) stateFn { l.emitValue(String, str) case r == '`': l.scanRawString(r) + case (r == 'b' || r == 'B') && (l.peek() == '\'' || l.peek() == '"'): + quote := l.next() + l.scanString(quote) + str, err := unescapeBytes(l.word()[1:]) // skip 'b' + if err != nil { + l.error("%v", err) + } + l.emitValue(Bytes, str) case '0' <= r && r <= '9': l.backup() return number diff --git a/parser/lexer/token.go b/parser/lexer/token.go index c809c690e..1041784e7 100644 --- a/parser/lexer/token.go +++ b/parser/lexer/token.go @@ -12,6 +12,7 @@ const ( Identifier Kind = "Identifier" Number Kind = "Number" String Kind = "String" + Bytes Kind = "Bytes" Operator Kind = "Operator" Bracket Kind = "Bracket" EOF Kind = "EOF" diff --git a/parser/lexer/utils.go b/parser/lexer/utils.go index 6aa088ae3..13cb5f314 100644 --- a/parser/lexer/utils.go +++ b/parser/lexer/utils.go @@ -54,6 +54,49 @@ func unescape(value string) (string, error) { return buf.String(), nil } +// unescapeBytes takes a quoted string, unquotes, and unescapes it as bytes. +func unescapeBytes(value string) (string, error) { + // All strings normalize newlines to the \n representation. + value = newlineNormalizer.Replace(value) + n := len(value) + + // Nothing to unescape / decode. + if n < 2 { + return value, fmt.Errorf("unable to unescape string") + } + + // Quoted string of some form, must have same first and last char. + if value[0] != value[n-1] || (value[0] != '"' && value[0] != '\'') { + return value, fmt.Errorf("unable to unescape string") + } + + value = value[1 : n-1] + + // The string contains escape characters. + // The following logic is adapted from `strconv/quote.go` + var runeTmp [utf8.UTFMax]byte + size := 3 * uint64(n) / 2 + if size >= math.MaxInt { + return "", fmt.Errorf("too large string") + } + buf := new(strings.Builder) + buf.Grow(int(size)) + for len(value) > 0 { + c, multibyte, rest, err := unescapeByteChar(value) + if err != nil { + return "", err + } + value = rest + if c < utf8.RuneSelf || !multibyte { + buf.WriteByte(byte(c)) + } else { + n := utf8.EncodeRune(runeTmp[:], c) + buf.Write(runeTmp[:n]) + } + } + return buf.String(), nil +} + // unescapeChar takes a string input and returns the following info: // // value - the escaped unicode rune at the front of the string. @@ -208,6 +251,91 @@ func unescapeChar(s string) (value rune, multibyte bool, tail string, err error) return } +// unescapeByteChar unescapes a single character or escape sequence from a bytes literal. +// Unlike unescapeChar, this only supports byte-level escapes (\x, octal) and rejects +// Unicode escapes (\u, \U) since bytes literals represent raw byte sequences. +// +// Note: We cannot use strconv.UnquoteChar here because it interprets \x and octal +// escapes as Unicode codepoints (e.g., \xff → codepoint 255 → 2 UTF-8 bytes), +// whereas bytes literals require them as raw byte values (\xff → single byte 255). +func unescapeByteChar(s string) (value rune, multibyte bool, tail string, err error) { + // Non-escape: return the character as-is. + // For bytes literals, we accept UTF-8 sequences but they get encoded back to bytes. + c := s[0] + if c != '\\' { + if c >= utf8.RuneSelf { + r, size := utf8.DecodeRuneInString(s) + return r, true, s[size:], nil + } + return rune(c), false, s[1:], nil + } + + // Escape sequence: need at least one more character. + if len(s) <= 1 { + return 0, false, "", fmt.Errorf("unable to unescape string, found '\\' as last character") + } + + c = s[1] + s = s[2:] + + switch c { + // Simple escape sequences + case 'a': + return '\a', false, s, nil + case 'b': + return '\b', false, s, nil + case 'f': + return '\f', false, s, nil + case 'n': + return '\n', false, s, nil + case 'r': + return '\r', false, s, nil + case 't': + return '\t', false, s, nil + case 'v': + return '\v', false, s, nil + case '\\': + return '\\', false, s, nil + case '\'': + return '\'', false, s, nil + case '"': + return '"', false, s, nil + case '`': + return '`', false, s, nil + case '?': + return '?', false, s, nil + + // Hex escape: \xNN (exactly 2 hex digits, value 0-255) + case 'x', 'X': + if len(s) < 2 { + return 0, false, "", fmt.Errorf("unable to unescape string") + } + hi, ok1 := unhex(s[0]) + lo, ok2 := unhex(s[1]) + if !ok1 || !ok2 { + return 0, false, "", fmt.Errorf("unable to unescape string") + } + return hi<<4 | lo, false, s[2:], nil + + // Octal escape: \NNN (3 octal digits, value 0-255) + case '0', '1', '2', '3': + if len(s) < 2 { + return 0, false, "", fmt.Errorf("unable to unescape octal sequence in string") + } + if s[0] < '0' || s[0] > '7' || s[1] < '0' || s[1] > '7' { + return 0, false, "", fmt.Errorf("unable to unescape octal sequence in string") + } + v := rune(c-'0')*64 + rune(s[0]-'0')*8 + rune(s[1]-'0') + if v > 255 { + return 0, false, "", fmt.Errorf("unable to unescape string") + } + return v, false, s[2:], nil + + default: + return 0, false, "", fmt.Errorf("unable to unescape string") + } +} + func unhex(b byte) (rune, bool) { c := rune(b) switch { diff --git a/parser/parser.go b/parser/parser.go index 9ccf47830..a0c6d4491 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -530,6 +530,13 @@ func (p *Parser) parseSecondary() Node { return nil } + case Bytes: + p.next() + node = p.createNode(&BytesNode{Value: []byte(token.Value)}, token.Location) + if node == nil { + return nil + } + default: if token.Is(Bracket, "[") { node = p.parseArrayExpression(token) diff --git a/parser/parser_test.go b/parser/parser_test.go index 85c465750..963a23050 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -87,6 +87,14 @@ world`}, "nil", &NilNode{}, }, + { + `b"hello"`, + &BytesNode{Value: []byte("hello")}, + }, + { + `b'\xff\x00'`, + &BytesNode{Value: []byte{255, 0}}, + }, { "-3", &UnaryNode{Operator: "-",