// 771. Pratt Parser for Operator Precedence Expressions
// Handles: (1+2)*3, unary minus, right-assoc ^
// ── Token ────────────────────────────────────────────────────────────────────
/// Lexical tokens of the expression language.
#[derive(Debug, Clone, PartialEq)]
pub enum Token {
    /// Numeric literal.
    Num(f64),
    /// `+`
    Plus,
    /// `-` (both the infix and the prefix form share this token).
    Minus,
    /// `*`
    Star,
    /// `/`
    Slash,
    /// `^`
    Caret,
    /// `(`
    LParen,
    /// `)`
    RParen,
    /// End-of-input marker.
    Eof,
}
// ── AST ──────────────────────────────────────────────────────────────────────
/// Expression tree for arithmetic expressions.
#[derive(Debug, Clone)]
pub enum Expr {
    /// Numeric literal.
    Num(f64),
    /// Prefix operator `op` applied to a single operand.
    Unary {
        op: char,
        operand: Box<Expr>,
    },
    /// Infix operator `op` applied to a left and a right operand.
    Binary {
        op: char,
        left: Box<Expr>,
        right: Box<Expr>,
    },
}
impl Expr {
pub fn eval(&self) -> f64 {
match self {
Expr::Num(n) => *n,
Expr::Unary { op: '-', operand } => -operand.eval(),
Expr::Unary { operand, .. } => operand.eval(),
Expr::Binary { op, left, right } => {
let (l, r) = (left.eval(), right.eval());
match op {
'+' => l + r, '-' => l - r,
'*' => l * r, '/' => l / r,
'^' => l.powf(r),
_ => panic!("unknown op"),
}
}
}
}
}
// ── Lexer ────────────────────────────────────────────────────────────────────
/// Splits `s` into tokens; the returned vector always ends with [`Token::Eof`].
///
/// Spaces and tabs are skipped. A numeric literal is a run of ASCII digits that
/// contains at most one `.`: a second `.` ends the literal and starts the next one,
/// so `"1.2.3"` lexes as `1.2` followed by `.3`. This keeps every lexeme handed to
/// the float parser well-formed, so malformed user input cannot cause a panic.
///
/// Characters that are not part of the grammar are skipped, as is a `.` that has no
/// adjacent digit.
pub fn tokenize(s: &str) -> Vec<Token> {
    let chars: Vec<char> = s.chars().collect();
    let mut tokens = Vec::new();
    let mut i = 0;
    while i < chars.len() {
        match chars[i] {
            '+' => tokens.push(Token::Plus),
            '-' => tokens.push(Token::Minus),
            '*' => tokens.push(Token::Star),
            '/' => tokens.push(Token::Slash),
            '^' => tokens.push(Token::Caret),
            '(' => tokens.push(Token::LParen),
            ')' => tokens.push(Token::RParen),
            c if c.is_ascii_digit() || c == '.' => {
                let start = i;
                let mut seen_dot = false;
                // The first character always matches, so `i` advances at least once.
                while let Some(&d) = chars.get(i) {
                    if d == '.' && !seen_dot {
                        seen_dot = true;
                    } else if !d.is_ascii_digit() {
                        break;
                    }
                    i += 1;
                }
                let lexeme: String = chars[start..i].iter().collect();
                // With at most one dot, the only lexeme that fails to parse is a lone
                // "."; drop it like any other unrecognised character.
                if let Ok(value) = lexeme.parse::<f64>() {
                    tokens.push(Token::Num(value));
                }
                continue;
            }
            // Whitespace and unrecognised characters produce no token.
            _ => {}
        }
        i += 1;
    }
    tokens.push(Token::Eof);
    tokens
}
// ── Pratt parser ─────────────────────────────────────────────────────────────
/// Parser state: the token stream being parsed and a cursor into it.
pub struct Parser {
/// Tokens to parse, in source order.
tokens: Vec<Token>,
/// Index into `tokens` of the next token to be read.
pos: usize,
}
/// Error returned when the input is not a well-formed expression.
///
/// The wrapped `String` is the human-readable message. The type implements
/// [`std::fmt::Display`] (printing that message) and [`std::error::Error`], so it can
/// be shown to users and propagated with `?` into `Box<dyn std::error::Error>`.
#[derive(Debug)]
pub struct ParseError(pub String);

impl std::fmt::Display for ParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

impl std::error::Error for ParseError {}
impl Parser {
    /// Creates a parser positioned at the first token of `tokens`.
    pub fn new(tokens: Vec<Token>) -> Self {
        Self { tokens, pos: 0 }
    }

    /// Returns the token under the cursor without advancing.
    ///
    /// Yields `Token::Eof` once the cursor has moved past the last token.
    fn peek(&self) -> &Token {
        match self.tokens.get(self.pos) {
            Some(tok) => tok,
            None => &Token::Eof,
        }
    }

    /// Returns a copy of the token under the cursor and advances the cursor by one.
    fn consume(&mut self) -> Token {
        let current = self.peek().clone();
        self.pos += 1;
        current
    }

    /// Returns `(left_bp, right_bp)` for infix operators, `None` for any other token.
    ///
    /// `+ -` and `* /` have a right power above their left power (left-associative);
    /// `^` has a right power below its left power (right-associative).
    fn infix_bp(tok: &Token) -> Option<(u8, u8)> {
        let powers = match tok {
            Token::Plus | Token::Minus => (10, 11),
            Token::Star | Token::Slash => (20, 21),
            Token::Caret => (30, 29),
            _ => return None,
        };
        Some(powers)
    }

    /// Maps an operator token to the character stored in the AST; `'?'` for any
    /// token that is not an operator.
    fn op_char(tok: &Token) -> char {
        match tok {
            Token::Plus => '+',
            Token::Minus => '-',
            Token::Star => '*',
            Token::Slash => '/',
            Token::Caret => '^',
            _ => '?',
        }
    }

    /// Parses a prefix position: a number, a unary minus, or a parenthesised
    /// expression.
    ///
    /// The unary-minus operand is parsed with minimum binding power 25, which lies
    /// above the powers of `*` and `/` and below the left power of `^`.
    fn parse_nud(&mut self) -> Result<Expr, ParseError> {
        let first = self.consume();
        match first {
            Token::Num(value) => Ok(Expr::Num(value)),
            Token::Minus => {
                let operand = Box::new(self.parse_bp(25)?);
                Ok(Expr::Unary { op: '-', operand })
            }
            Token::LParen => {
                let inner = self.parse_bp(0)?;
                match self.consume() {
                    Token::RParen => Ok(inner),
                    _ => Err(ParseError("expected ')'".into())),
                }
            }
            other => Err(ParseError(format!("unexpected token: {other:?}"))),
        }
    }

    /// Parses an expression, folding in infix operators for as long as their left
    /// binding power is strictly greater than `min_bp`.
    ///
    /// Parsing stops at the first token that is not such an operator; that token is
    /// left unconsumed.
    pub fn parse_bp(&mut self, min_bp: u8) -> Result<Expr, ParseError> {
        let mut lhs = self.parse_nud()?;
        while let Some((lbp, rbp)) = Self::infix_bp(self.peek()) {
            if lbp <= min_bp {
                break;
            }
            let op = Self::op_char(&self.consume());
            let rhs = self.parse_bp(rbp)?;
            lhs = Expr::Binary {
                op,
                left: Box::new(lhs),
                right: Box::new(rhs),
            };
        }
        Ok(lhs)
    }
}
pub fn parse(input: &str) -> Result<Expr, ParseError> {
Parser::new(tokenize(input)).parse_bp(0)
}
/// Evaluates a fixed table of sample expressions and prints each result next to its
/// expected value, marked `✓` when the two agree within `1e-9` and `✗` otherwise.
fn main() {
    let tests: &[(&str, f64)] = &[
        ("(1 + 2) * 3", 9.0),
        ("2 ^ 3 ^ 2", 512.0), // right-assoc: 2^(3^2) = 2^9
        ("-2 * 3", -6.0),
        ("1 + 2 * 3 - 4 / 2", 5.0),
        ("(10 - 3) / (2 + 5)", 1.0),
    ];
    for (expr, expected) in tests {
        // The table above is hard-coded, so a parse failure here is a bug.
        let result = parse(expr).unwrap().eval();
        // The two markers must differ, otherwise the output cannot distinguish a
        // passing row from a failing one.
        let ok = if (result - expected).abs() < 1e-9 { "✓" } else { "✗" };
        println!("{expr:30} = {result:8.2} (expected {expected:.2}) {ok}");
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses and evaluates `src`, panicking if it does not parse.
    fn evaluate(src: &str) -> f64 {
        let tree = parse(src).unwrap();
        tree.eval()
    }

    #[test]
    fn basic_ops() {
        let value = evaluate("1 + 2");
        assert_eq!(value, 3.0);
    }

    #[test]
    fn precedence() {
        let value = evaluate("1 + 2 * 3");
        assert_eq!(value, 7.0);
    }

    #[test]
    fn parens_change_prec() {
        let value = evaluate("(1 + 2) * 3");
        assert_eq!(value, 9.0);
    }

    #[test]
    fn right_assoc_power() {
        // 2 ^ (3 ^ 2) = 2 ^ 9; compared with a tolerance because `powf` is used.
        let delta = (evaluate("2 ^ 3 ^ 2") - 512.0).abs();
        assert!(delta < 1e-9);
    }

    #[test]
    fn unary_minus() {
        let value = evaluate("-3 + 5");
        assert_eq!(value, 2.0);
    }

    #[test]
    fn unary_in_parens() {
        let value = evaluate("(-3) * 2");
        assert_eq!(value, -6.0);
    }
}
(* Pratt parser in OCaml — handles (1+2)*3, unary minus, right-assoc ^ *)
(* ── Tokens ─────────────────────────────────────────────────────────────────── *)
(* Lexical tokens of the expression language. *)
type token =
  | TNum of float  (* numeric literal *)
  | TPlus          (* + *)
  | TMinus         (* - (infix and prefix) *)
  | TStar          (* * *)
  | TSlash         (* / *)
  | TCaret         (* ^ *)
  | TLParen        (* ( *)
  | TRParen        (* ) *)
  | TEof           (* end-of-input marker *)
(* ── AST ────────────────────────────────────────────────────────────────────── *)
(* Expression tree for arithmetic expressions. *)
type expr =
  | Num of float                  (* numeric literal *)
  | Unary of char * expr          (* prefix operator, operand *)
  | Binary of char * expr * expr  (* infix operator, left operand, right operand *)
(* Recursively evaluates an expression tree to a float.

   A unary node negates its operand when its operator is '-'; any other unary
   operator passes the operand through unchanged.

   Raises Invalid_argument if a binary node carries an operator other than
   + - * / ^. Silently evaluating such a node to 0.0 would hide a malformed tree
   behind a plausible-looking result. *)
let rec eval = function
  | Num n -> n
  | Unary ('-', e) -> -. (eval e)
  | Unary (_, e) -> eval e
  | Binary (op, a, b) ->
    let l = eval a in
    let r = eval b in
    (match op with
     | '+' -> l +. r
     | '-' -> l -. r
     | '*' -> l *. r
     | '/' -> l /. r
     | '^' -> Float.pow l r
     | _ -> invalid_arg (Printf.sprintf "eval: unknown binary operator %C" op))
(* ── Lexer ──────────────────────────────────────────────────────────────────── *)
(* Splits [s] into a token list that always ends with TEof.

   Spaces and tabs are skipped. A numeric literal is a run of digits containing at
   most one '.'; it may begin with '.' when a digit follows immediately. A second
   '.' ends the literal, so "1.2.3" lexes as 1.2 followed by .3 and float_of_string
   is never handed a malformed lexeme.

   Characters that are not part of the grammar are skipped, as is a '.' that is not
   followed by a digit. *)
let tokenize s =
  let len = String.length s in
  let is_digit ch = ch >= '0' && ch <= '9' in
  let tokens = ref [] in
  let push tok = tokens := tok :: !tokens in
  let i = ref 0 in
  while !i < len do
    let c = s.[!i] in
    incr i;
    match c with
    | ' ' | '\t' -> ()
    | '+' -> push TPlus
    | '-' -> push TMinus
    | '*' -> push TStar
    | '/' -> push TSlash
    | '^' -> push TCaret
    | '(' -> push TLParen
    | ')' -> push TRParen
    | c when is_digit c || (c = '.' && !i < len && is_digit s.[!i]) ->
      (* [i] was already advanced past the first character of the literal. *)
      let start = !i - 1 in
      let seen_dot = ref (c = '.') in
      let scanning = ref true in
      while !scanning && !i < len do
        let d = s.[!i] in
        if is_digit d then incr i
        else if d = '.' && not !seen_dot then begin
          seen_dot := true;
          incr i
        end
        else scanning := false
      done;
      let lexeme = String.sub s start (!i - start) in
      (* Prefix a zero so that a literal such as ".5" is read as "0.5". *)
      let lexeme = if c = '.' then "0" ^ lexeme else lexeme in
      push (TNum (float_of_string lexeme))
    | _ -> ()
  done;
  List.rev (TEof :: !tokens)
(* ── Pratt parser ───────────────────────────────────────────────────────────── *)
(* Unread input of the parser, held in one module-level mutable cell. *)
let tokens : token list ref = ref []

(* Returns the next unread token without removing it; TEof when none remain. *)
let current () =
  match !tokens with
  | [] -> TEof
  | head :: _ -> head

(* Drops the next unread token; does nothing when none remain. *)
let consume () =
  match !tokens with
  | [] -> ()
  | _ :: rest -> tokens := rest
(* Returns (left_bp, right_bp) for a token.

   + - and * / have a right power above their left power (left-associative); ^ has
   a right power below its left power (right-associative). Every other token gets
   (0, 0). *)
let infix_bp tok =
  match tok with
  | TPlus | TMinus -> (10, 11)
  | TStar | TSlash -> (20, 21)
  | TCaret -> (30, 29)
  | _ -> (0, 0)
(* Maps an operator token to the character stored in the AST; '?' for any token
   that is not an operator. *)
let op_char tok =
  match tok with
  | TPlus -> '+'
  | TMinus -> '-'
  | TStar -> '*'
  | TSlash -> '/'
  | TCaret -> '^'
  | _ -> '?'
(* Human-readable spelling of a token, used in error messages. *)
let describe_token = function
  | TNum n -> Printf.sprintf "number %g" n
  | TPlus -> "'+'"
  | TMinus -> "'-'"
  | TStar -> "'*'"
  | TSlash -> "'/'"
  | TCaret -> "'^'"
  | TLParen -> "'('"
  | TRParen -> "')'"
  | TEof -> "end of input"

(* Parses an expression, folding in infix operators for as long as their left
   binding power is strictly greater than [min_bp]. The first token that does not
   qualify is left unconsumed. *)
let rec parse_bp min_bp =
  let rec loop left =
    let tok = current () in
    let (lbp, rbp) = infix_bp tok in
    if lbp <= min_bp then left
    else begin
      consume ();
      let right = parse_bp rbp in
      loop (Binary (op_char tok, left, right))
    end
  in
  loop (parse_nud ())

(* Parses a prefix position: a number, a unary minus, or a parenthesised
   expression. Raises Failure on any other token, naming the token in the message.
   The previous message printed Obj.tag of the token, which is the same number for
   every argument-less constructor and so identified nothing. *)
and parse_nud () =
  match current () with
  | TNum n -> consume (); Num n
  | TMinus ->
    consume ();
    (* 25 lies above the powers of * and / and below the left power of ^. *)
    let e = parse_bp 25 in
    Unary ('-', e)
  | TLParen ->
    consume ();
    let e = parse_bp 0 in
    (match current () with
     | TRParen -> consume ()
     | _ -> failwith "expected ')'");
    e
  | t -> failwith ("unexpected token: " ^ describe_token t)
(* Tokenizes [s] and parses it as one complete expression.

   Raises Failure if the expression is malformed, or if any token remains after a
   complete expression has been parsed (for example "1 2" or "(1 + 2))"), instead
   of silently ignoring the trailing input. *)
let parse s =
  tokens := tokenize s;
  let e = parse_bp 0 in
  match current () with
  | TEof -> e
  | _ -> failwith "unexpected trailing input"
(* Parses, evaluates and prints a fixed list of sample expressions.

   List elements are separated by ';'. Separating them with ',' would build a
   single tuple rather than a list of strings. *)
let () =
  let tests = [
    "(1 + 2) * 3";        (* = 9 *)
    "2 ^ 3 ^ 2";          (* = 512, right-assoc: 2^(3^2) = 2^9 *)
    "-2 * 3";             (* = -6 *)
    "1 + 2 * 3 - 4 / 2";  (* = 5 *)
  ] in
  List.iter (fun s ->
    Printf.printf "%s = %g\n" s (eval (parse s))
  ) tests