diff --git a/Cargo.toml b/Cargo.toml index f0c3c0a..929a2d4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,3 @@ name = "monkeyrs" version = "0.1.0" edition = "2021" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] diff --git a/src/ast.rs b/src/ast.rs index 28e73eb..c52628c 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -3,19 +3,122 @@ use std::{ rc::Rc, }; -use crate::token::{Token, TokenType}; +use crate::token::TokenType; -pub trait Node: Debug {} +#[derive(Debug, PartialEq, Eq)] +pub enum NodeType { + Program, + LetStatement, + ReturnStatement, + ExpressionStatement, + IdentifierExpression, + IntegerLiteralExpression, + PrefixExpression, + InfixExpression, + BooleanExpression, + DummyExpression, +} + +#[macro_export] +macro_rules! node_type_eq { + ( $x:expr, $y:expr ) => {{ + if $x.node_type() != $y.node_type() { + false + } else { + match $x.node_type() { + NodeType::Program => unreachable!(), + NodeType::DummyExpression => true, + NodeType::BooleanExpression => { + let __x = unsafe { $x.downcast::() }; + let __y = unsafe { $y.downcast::() }; + __x == __y + } + NodeType::ExpressionStatement => { + let __x = unsafe { $x.downcast::() }; + let __y = unsafe { $y.downcast::() }; + __x == __y + } + NodeType::LetStatement => { + let __x = unsafe { $x.downcast::() }; + let __y = unsafe { $y.downcast::() }; + __x == __y + } + NodeType::ReturnStatement => { + let __x = unsafe { $x.downcast::() }; + let __y = unsafe { $y.downcast::() }; + __x == __y + } + NodeType::InfixExpression => { + let __x = unsafe { $x.downcast::() }; + let __y = unsafe { $y.downcast::() }; + __x == __y + } + NodeType::PrefixExpression => { + let __x = unsafe { $x.downcast::() }; + let __y = unsafe { $y.downcast::() }; + __x == __y + } + NodeType::IdentifierExpression => { + let __x = unsafe { $x.downcast::() }; + let __y = unsafe { $y.downcast::() }; + __x == __y + } + NodeType::IntegerLiteralExpression => { + let __x = unsafe { $x.downcast::() }; + let __y = unsafe { $y.downcast::() }; + __x == __y + } + } + } + }}; +} + +pub trait Node: Debug { + fn node_type(&self) -> NodeType; +} pub trait Statement: Node {} +impl dyn Statement { + unsafe fn downcast(&self) -> &T { + &*(self as *const dyn Statement as *const T) + } +} + pub trait Expression: Node {} +impl dyn Expression { + unsafe fn downcast(&self) -> &T { + &*(self as *const dyn Expression as *const T) + } +} + pub struct Program { pub statements: Vec>, } -impl Node for Program {} +impl PartialEq for Program { + fn eq(&self, other: &Self) -> bool { + if self.statements.len() != other.statements.len() { + return false; + } + let mut other_iter = other.statements.iter(); + for stmt in self.statements.iter() { + let other = other_iter.next().unwrap(); + + if !node_type_eq!(stmt, other) { + return false; + } + } + true + } +} + +impl Node for Program { + fn node_type(&self) -> NodeType { + NodeType::Program + } +} impl Debug for Program { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -33,56 +136,97 @@ impl Debug for Program { #[derive(Debug)] pub struct Let { - pub token: Token, pub name: Identifier, pub value: Rc, } -impl Node for Let {} +impl PartialEq for Let { + fn eq(&self, other: &Self) -> bool { + if self.name != other.name { + return false; + } + + node_type_eq!(self.value, other.value) + } +} + +impl Node for Let { + fn node_type(&self) -> NodeType { + NodeType::LetStatement + } +} impl Statement for Let {} #[derive(Debug)] pub struct Return { - pub token: Token, pub value: Rc, } +impl PartialEq for Return { + fn eq(&self, other: &Self) -> bool { + node_type_eq!(self.value, other.value) + } +} + +impl Node for Return { + fn node_type(&self) -> NodeType { + NodeType::ReturnStatement + } +} + impl Statement for Return {} -impl Node for Return {} - #[derive(Debug)] pub struct ExpressionStatement { - // TODO: probably not needed - pub token: Token, pub expression: Rc, } +impl PartialEq for ExpressionStatement { + fn eq(&self, other: &Self) -> bool { + let expr = &self.expression; + let other = &other.expression; + + node_type_eq!(expr, other) + } +} + +impl Node for ExpressionStatement { + fn node_type(&self) -> NodeType { + NodeType::ExpressionStatement + } +} + impl Statement for ExpressionStatement {} -impl Node for ExpressionStatement {} - -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq)] pub struct Identifier { pub token_type: TokenType, pub value: String, } -impl Node for Identifier {} +impl Node for Identifier { + fn node_type(&self) -> NodeType { + NodeType::IdentifierExpression + } +} impl Expression for Identifier {} -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq)] pub struct IntegerLiteral { pub value: i64, } -impl Node for IntegerLiteral {} +impl Node for IntegerLiteral { + fn node_type(&self) -> NodeType { + NodeType::IntegerLiteralExpression + } +} impl Expression for IntegerLiteral {} -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq)] pub enum PrefixOperator { Minus, Bang, @@ -105,11 +249,25 @@ pub struct Prefix { pub right: Rc, } -impl Node for Prefix {} +impl PartialEq for Prefix { + fn eq(&self, other: &Self) -> bool { + if self.operator != other.operator { + return false; + } + println!("Prefix PartialEq"); + node_type_eq!(self.right, other.right) + } +} + +impl Node for Prefix { + fn node_type(&self) -> NodeType { + NodeType::PrefixExpression + } +} impl Expression for Prefix {} -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq)] pub enum InfixOperator { Plus, Minus, @@ -168,13 +326,42 @@ pub struct Infix { pub right: Rc, } -impl Node for Infix {} +impl PartialEq for Infix { + fn eq(&self, other: &Self) -> bool { + self.operator == other.operator + && node_type_eq!(self.left, other.left) + && node_type_eq!(self.right, other.right) + } +} + +impl Node for Infix { + fn node_type(&self) -> NodeType { + NodeType::InfixExpression + } +} impl Expression for Infix {} +#[derive(Debug, PartialEq, Eq)] +pub struct Boolean { + pub value: bool, +} + +impl Node for Boolean { + fn node_type(&self) -> NodeType { + NodeType::BooleanExpression + } +} + +impl Expression for Boolean {} + #[derive(Debug)] pub struct DummyExpression {} -impl Node for DummyExpression {} +impl Node for DummyExpression { + fn node_type(&self) -> NodeType { + NodeType::DummyExpression + } +} impl Expression for DummyExpression {} diff --git a/src/lexer.rs b/src/lexer.rs index 3a403b7..13ded27 100644 --- a/src/lexer.rs +++ b/src/lexer.rs @@ -158,45 +158,45 @@ mod tests { 10 == 10; 10 != 9; " - .to_string(); + .into(); use Token::*; let tests = vec![ Let, - Ident("five".to_string()), + Ident("five".into()), Assign, Int(5), Semicolon, Let, - Ident("ten".to_string()), + Ident("ten".into()), Assign, Int(10), Semicolon, Let, - Ident("add".to_string()), + Ident("add".into()), Assign, Function, Lparen, - Ident("x".to_string()), + Ident("x".into()), Comma, - Ident("y".to_string()), + Ident("y".into()), Rparen, Lbrace, - Ident("x".to_string()), + Ident("x".into()), Plus, - Ident("y".to_string()), + Ident("y".into()), Semicolon, Rbrace, Semicolon, Let, - Ident("result".to_string()), + Ident("result".into()), Assign, - Ident("add".to_string()), + Ident("add".into()), Lparen, - Ident("five".to_string()), + Ident("five".into()), Comma, - Ident("ten".to_string()), + Ident("ten".into()), Rparen, Semicolon, Bang, diff --git a/src/parser.rs b/src/parser.rs index a6efb16..827bda4 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -2,8 +2,8 @@ use std::{fmt, rc::Rc}; use crate::{ ast::{ - DummyExpression, Expression, ExpressionStatement, Identifier, Infix, InfixOperator, - IntegerLiteral, Let, Prefix, PrefixOperator, Program, Return, Statement, + Boolean, DummyExpression, Expression, ExpressionStatement, Identifier, Infix, + InfixOperator, IntegerLiteral, Let, Prefix, PrefixOperator, Program, Return, Statement, }, lexer::Lexer, token::{self, Token, TokenType}, @@ -64,7 +64,7 @@ impl fmt::Display for Error { "expected token `{expected}`, actual token: `{}`", match actual { Some(token) => format!("{token}"), - None => "None".to_string(), + None => "None".into(), }, ), NoPrefixParseFnFound(t) => write!(f, "no prefix parse function found for `{t}`"), @@ -131,12 +131,6 @@ impl Parser { } fn parse_let_statement(&mut self) -> Option> { - let token = if let Some(token) = &self.cur_token { - token.clone() - } else { - return None; - }; - let value = self.expect_peek_ident()?; let name = Identifier { @@ -148,35 +142,33 @@ impl Parser { return None; } - while !self.cur_token_is(TokenType::Semicolon) { - self.next(); - } - let value = Rc::new(DummyExpression {}); + self.next(); - Some(Rc::new(Let { token, name, value })) + // TODO: maybe not unwrap + let value = self.parse_expression(Precedence::Lowest).unwrap(); + self.next(); + + Some(Rc::new(Let { name, value })) } fn parse_return_statement(&mut self) -> Option> { - let token = self.cur_token.clone().unwrap(); + self.next(); - while !self.cur_token_is(TokenType::Semicolon) { - self.next(); - } - let value = Rc::new(DummyExpression {}); + // TODO: maybe not unwrap + let value = self.parse_expression(Precedence::Lowest).unwrap(); + self.next(); - Some(Rc::new(Return { token, value })) + Some(Rc::new(Return { value })) } fn parse_expression_statement(&mut self) -> Option> { - let token = self.cur_token.clone()?; - let expression = self.parse_expression(Precedence::Lowest)?; if self.peek_token_is(TokenType::Semicolon) { self.next(); } - Some(Rc::new(ExpressionStatement { token, expression })) + Some(Rc::new(ExpressionStatement { expression })) } fn parse_expression(&mut self, precedence: Precedence) -> Option> { @@ -237,6 +229,12 @@ impl Parser { Rc::new(Prefix { operator, right }) } + fn parse_boolean(&mut self) -> Rc { + Rc::new(Boolean { + value: self.cur_token_is(TokenType::True), + }) + } + fn parse_infix_expression(&mut self, left: Rc) -> Rc { let token = if let Some(token) = &self.cur_token { token @@ -375,6 +373,8 @@ impl Parser { Bang => Some(Box::new(|parser| { Self::parse_prefix_expression(parser, PrefixOperator::Bang) })), + True => Some(Box::new(Self::parse_boolean)), + False => Some(Box::new(Self::parse_boolean)), _ => None, } } @@ -417,9 +417,12 @@ mod tests { use std::rc::Rc; use crate::{ - ast::{InfixOperator, PrefixOperator, Statement}, + ast::{ + Boolean, ExpressionStatement, Identifier, Infix, InfixOperator, IntegerLiteral, Let, + Prefix, PrefixOperator, Program, Return, Statement, + }, lexer::Lexer, - token::Token, + token::{Token, TokenType}, }; use super::Parser; @@ -430,7 +433,7 @@ mod tests { let y = 10;\ let foobar = 838383;\ " - .to_string(); + .into(); let lexer = Lexer::new(source); @@ -439,14 +442,34 @@ mod tests { let program = parser.parse().unwrap(); check_parser_errors(parser); - assert_eq!(program.statements.len(), 3); - - let expected_identifiers = vec!["x", "y", "foobar"]; - let mut statements_iter = program.statements.iter(); - for tt in expected_identifiers { - let statement = statements_iter.next().unwrap(); - test_let_statement(statement.clone(), tt); - } + assert_eq!( + program, + Program { + statements: vec![ + Rc::new(Let { + name: Identifier { + token_type: TokenType::Let, + value: "x".into() + }, + value: Rc::new(IntegerLiteral { value: 5 }) + }), + Rc::new(Let { + name: Identifier { + token_type: TokenType::Let, + value: "y".into() + }, + value: Rc::new(IntegerLiteral { value: 10 }) + }), + Rc::new(Let { + name: Identifier { + token_type: TokenType::Let, + value: "foobar".into() + }, + value: Rc::new(IntegerLiteral { value: 838383 }) + }) + ] + } + ); } #[test] @@ -455,7 +478,7 @@ mod tests { return 10;\ return 838383;\ " - .to_string(); + .into(); let lexer = Lexer::new(source); @@ -464,19 +487,27 @@ mod tests { let program = parser.parse().unwrap(); check_parser_errors(parser); - assert_eq!(program.statements.len(), 3); - - for stmt in program.statements { - assert_eq!( - format!("{stmt:?}"), - "Return { token: Return, value: DummyExpression }" - ) - } + assert_eq!( + program, + Program { + statements: vec![ + Rc::new(Return { + value: Rc::new(IntegerLiteral { value: 5 }) + }), + Rc::new(Return { + value: Rc::new(IntegerLiteral { value: 10 }) + }), + Rc::new(Return { + value: Rc::new(IntegerLiteral { value: 838383 }) + }) + ] + } + ); } #[test] fn identifier_expression() { - let source = "foobar;".to_owned(); + let source = "foobar;".into(); let lexer = Lexer::new(source); @@ -485,17 +516,42 @@ mod tests { let program = parser.parse().unwrap(); check_parser_errors(parser); - let expected_identifiers = vec!["foobar"]; - let mut statements_iter = program.statements.iter(); - for tt in expected_identifiers { - let statement = statements_iter.next().unwrap(); - test_identifier_expression(statement.clone(), tt); - } + assert_eq!( + program, + Program { + statements: vec![Rc::new(ExpressionStatement { + expression: Rc::new(Identifier { + token_type: TokenType::Ident, + value: "foobar".into() + }) + })] + } + ); } #[test] fn integer_literal_expression() { - let source = "6;".to_owned(); + let source = "6;".into(); + + let lexer = Lexer::new(source); + + let mut parser = Parser::new(lexer); + + let program = parser.parse().unwrap(); + check_parser_errors(parser); + assert_eq!( + program, + Program { + statements: vec![Rc::new(ExpressionStatement { + expression: Rc::new(IntegerLiteral { value: 6 }) + })] + } + ); + } + + #[test] + fn bool_expression() { + let source = "true; false".into(); let lexer = Lexer::new(source); @@ -504,12 +560,19 @@ mod tests { let program = parser.parse().unwrap(); check_parser_errors(parser); - let expected_integers = vec![6]; - let mut statements_iter = program.statements.iter(); - for tt in expected_integers { - let statement = statements_iter.next().unwrap(); - test_integer_literal_expression(statement.clone(), tt); - } + assert_eq!( + program, + Program { + statements: vec![ + Rc::new(ExpressionStatement { + expression: Rc::new(Boolean { value: true }) + }), + Rc::new(ExpressionStatement { + expression: Rc::new(Boolean { value: false }) + }) + ] + } + ); } #[test] @@ -523,13 +586,13 @@ mod tests { let tests = vec![ Test { - input: "!15".to_owned(), + input: "!15".into(), token: Token::Bang, operator: PrefixOperator::Bang, value: 15, }, Test { - input: "-15".to_owned(), + input: "-15".into(), token: Token::Minus, operator: PrefixOperator::Minus, value: 15, @@ -544,13 +607,16 @@ mod tests { let program = parser.parse().unwrap(); check_parser_errors(parser); - assert_eq!(program.statements.len(), 1); - - test_prefix_expression( - program.statements.first().unwrap().clone(), - test.token, - test.operator, - test.value, + assert_eq!( + program, + Program { + statements: vec![Rc::new(ExpressionStatement { + expression: Rc::new(Prefix { + operator: test.operator, + right: Rc::new(IntegerLiteral { value: test.value }) + }) + })] + } ); } } @@ -623,24 +689,25 @@ mod tests { let program = parser.parse().unwrap(); check_parser_errors(parser); - assert_eq!(program.statements.len(), 1); - - test_infix_expression( - program.statements.first().unwrap().clone(), - test.left_value, - test.operator, - test.right_value, + assert_eq!( + program, + Program { + statements: vec![Rc::new(ExpressionStatement { + expression: Rc::new(Infix { + operator: test.operator, + left: Rc::new(IntegerLiteral { + value: test.left_value + }), + right: Rc::new(IntegerLiteral { + value: test.right_value + }) + }) + })] + } ); } } - fn test_let_statement(stmt: Rc, name: &str) { - assert_eq!( - format!("{stmt:?}"), - format!("Let {{ token: Let, name: Identifier {{ token_type: Let, value: \"{name}\" }}, value: DummyExpression }}"), - ); - } - fn check_parser_errors(parser: Parser) { if parser.errors().len() == 0 { return; @@ -654,46 +721,4 @@ mod tests { panic!("{err}"); } - - fn test_identifier_expression(stmt: Rc, name: &str) { - assert_eq!( - format!("{stmt:?}"), - format!( - "ExpressionStatement {{ token: Ident(\"{name}\"), expression: Identifier {{ token_type: Ident, value: \"{name}\" }} }}" - ), - ); - } - - fn test_integer_literal_expression(stmt: Rc, num: i64) { - assert_eq!( - format!("{stmt:?}"), - format!( - "ExpressionStatement {{ token: Int({num}), expression: IntegerLiteral {{ value: {num} }} }}" - ), - ); - } - - fn test_prefix_expression( - stmt: Rc, - token: Token, - operator: PrefixOperator, - num: i64, - ) { - assert_eq!( - format!("{stmt:?}"), - format!("ExpressionStatement {{ token: {token:?}, expression: Prefix {{ operator: {operator:?}, right: IntegerLiteral {{ value: {num} }} }} }}"), - ); - } - - fn test_infix_expression( - stmt: Rc, - left: i64, - operator: InfixOperator, - right: i64, - ) { - assert_eq!( - format!("{stmt:?}"), - format!("ExpressionStatement {{ token: Int({left}), expression: Infix {{ operator: {operator:?}, left: IntegerLiteral {{ value: {left} }}, right: IntegerLiteral {{ value: {right} }} }} }}"), - ); - } } diff --git a/src/token.rs b/src/token.rs index ad6584d..f46ada6 100644 --- a/src/token.rs +++ b/src/token.rs @@ -1,6 +1,6 @@ use std::fmt::Display; -#[derive(Debug, PartialEq, Clone)] +#[derive(Debug, PartialEq, Eq, Clone)] pub enum Token { Illegal, EOF, @@ -154,7 +154,7 @@ impl Token { "if" => If, "else" => Else, "return" => Return, - ident => Ident(ident.to_string()), + ident => Ident(ident.into()), } } }