Skip to content

Commit

Permalink
support pub attribute (#263)
Browse files Browse the repository at this point in the history
* support pub attribute for struct fields
  • Loading branch information
katat authored Jan 22, 2025
1 parent 4596f2b commit bb87a88
Show file tree
Hide file tree
Showing 22 changed files with 153 additions and 44 deletions.
2 changes: 1 addition & 1 deletion examples/assignment.no
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
struct Thing {
xx: Field,
pub xx: Field,
}

fn try_to_mutate(thing: Thing) {
Expand Down
4 changes: 2 additions & 2 deletions examples/hint.no
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
struct Thing {
xx: Field,
yy: Field,
pub xx: Field,
pub yy: Field,
}

hint fn mul(lhs: Field, rhs: Field) -> Field {
Expand Down
4 changes: 2 additions & 2 deletions examples/types.no
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
struct Thing {
xx: Field,
yy: Field,
pub xx: Field,
pub yy: Field,
}

fn main(pub xx: Field, pub yy: Field) {
Expand Down
4 changes: 2 additions & 2 deletions examples/types_array.no
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
struct Thing {
xx: Field,
yy: Field,
pub xx: Field,
pub yy: Field,
}

fn main(pub xx: Field, pub yy: Field) {
Expand Down
4 changes: 2 additions & 2 deletions examples/types_array_output.no
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
struct Thing {
xx: Field,
yy: Field,
pub xx: Field,
pub yy: Field,
}

fn main(pub xx: Field, pub yy: Field) -> [Thing; 2] {
Expand Down
2 changes: 1 addition & 1 deletion src/circuit_writer/ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ impl<B: Backend> IRWriter<B> {
// find range of field
let mut start = 0;
let mut len = 0;
for (field, field_typ) in &struct_info.fields {
for (field, field_typ, _attribute) in &struct_info.fields {
if field == &rhs.value {
len = self.size_of(field_typ);
break;
Expand Down
4 changes: 2 additions & 2 deletions src/circuit_writer/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ impl<B: Backend> CircuitWriter<B> {
.clone();

let mut offset = 0;
for (_field_name, field_typ) in &struct_info.fields {
for (_field_name, field_typ, _attribute) in &struct_info.fields {
let len = self.size_of(field_typ);
let range = offset..(offset + len);
self.constrain_inputs_to_main(&input[range], field_typ, span)?;
Expand Down Expand Up @@ -501,7 +501,7 @@ impl<B: Backend> CircuitWriter<B> {
// find range of field
let mut start = 0;
let mut len = 0;
for (field, field_typ) in &struct_info.fields {
for (field, field_typ, _attribute) in &struct_info.fields {
if field == &rhs.value {
len = self.size_of(field_typ);
break;
Expand Down
3 changes: 3 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,9 @@ pub enum ErrorKind {
#[error("division by zero")]
DivisionByZero,

#[error("cannot access private field `{1}` of struct `{0}` from outside its methods.")]
PrivateFieldAccess(String, String),

#[error("lhs `{0}` is less than rhs `{1}`")]
NegativeLhsLessThanRhs(String, String),

Expand Down
2 changes: 1 addition & 1 deletion src/inputs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ impl<B: Backend> CompiledCircuit<B> {

// parse each field
let mut res = vec![];
for (field_name, field_ty) in fields {
for (field_name, field_ty, _attribute) in fields {
let value = map.remove(field_name).ok_or_else(|| {
ParsingError::MissingStructFieldIdent(field_name.to_string())
})?;
Expand Down
6 changes: 3 additions & 3 deletions src/mast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ impl<B: Backend> Mast<B> {

let mut sum = 0;

for (_, t) in &struct_info.fields {
for (_, t, _) in &struct_info.fields {
sum += self.size_of(t);
}

Expand Down Expand Up @@ -567,8 +567,8 @@ fn monomorphize_expr<B: Backend>(
let typ = struct_info
.fields
.iter()
.find(|(name, _)| name == &rhs.value)
.map(|(_, typ)| typ.clone());
.find(|(name, _, _)| name == &rhs.value)
.map(|(_, typ, _)| typ.clone());

let mexpr = expr.to_mast(
ctx,
Expand Down
2 changes: 1 addition & 1 deletion src/name_resolution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ impl NameResCtx {
self.resolve(module, true)?;

// we resolve the fully-qualified types of the fields
for (_field_name, field_typ) in fields {
for (_field_name, field_typ, _attribute) in fields {
self.resolve_typ_kind(&mut field_typ.kind)?;
}

Expand Down
25 changes: 23 additions & 2 deletions src/negative_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,7 @@ fn test_nonhint_call_with_unsafe() {
fn test_no_cst_struct_field_prop() {
let code = r#"
struct Thing {
val: Field,
pub val: Field,
}
fn gen(const LEN: Field) -> [Field; LEN] {
Expand All @@ -748,7 +748,7 @@ fn test_no_cst_struct_field_prop() {
fn test_mut_cst_struct_field_prop() {
let code = r#"
struct Thing {
val: Field,
pub val: Field,
}
fn gen(const LEN: Field) -> [Field; LEN] {
Expand All @@ -770,3 +770,24 @@ fn test_mut_cst_struct_field_prop() {
ErrorKind::ArgumentTypeMismatch(..)
));
}

#[test]
fn test_private_field_access() {
let code = r#"
struct Room {
pub beds: Field, // public
size: Field // private
}
fn main(pub beds: Field) {
let room = Room {beds: beds, size: 10};
room.size = 5; // not allowed
}
"#;

let res = tast_pass(code).0;
assert!(matches!(
res.unwrap_err().kind,
ErrorKind::PrivateFieldAccess(..)
));
}
2 changes: 2 additions & 0 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@ mod tests {
let parsed = StructDef::parse(ctx, tokens);
assert!(parsed.is_err());
assert!(parsed.as_ref().err().is_some());

println!("{:?}", parsed);
match &parsed.as_ref().err().unwrap().kind {
ErrorKind::ExpectedTokenNotKeyword(keyword, _) => {
assert_eq!(keyword, "pub");
Expand Down
38 changes: 34 additions & 4 deletions src/parser/structs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ use serde::{Deserialize, Serialize};
use crate::{
constants::Span,
error::{ErrorKind, Result},
lexer::{Token, TokenKind, Tokens},
lexer::{Keyword, Token, TokenKind, Tokens},
syntax::is_type,
};

use super::{
types::{Ident, ModulePath, Ty, TyKind},
types::{Attribute, AttributeKind, Ident, ModulePath, Ty, TyKind},
Error, ParserCtx,
};

Expand All @@ -17,7 +17,7 @@ pub struct StructDef {
//pub attribute: Attribute,
pub module: ModulePath, // name resolution
pub name: CustomType,
pub fields: Vec<(Ident, Ty)>,
pub fields: Vec<(Ident, Ty, Option<Attribute>)>,
pub span: Span,
}

Expand Down Expand Up @@ -55,6 +55,36 @@ impl StructDef {
tokens.bump(ctx);
break;
}

// check for pub keyword
// struct Foo { pub a: Field, b: Field }
// ^
let attribute = if matches!(
tokens.peek(),
Some(Token {
kind: TokenKind::Keyword(Keyword::Pub),
..
})
) {
let token = tokens.bump(ctx).unwrap();
// next token shouldn't be :
if tokens.peek().unwrap().kind == TokenKind::Colon {
return Err(ctx.error(
ErrorKind::ExpectedTokenNotKeyword(
"pub".to_string(),
TokenKind::Identifier("".to_string()),
),
token.span,
));
}
Some(Attribute {
kind: AttributeKind::Pub,
span: token.span,
})
} else {
None
};

// struct Foo { a: Field, b: Field }
// ^
let field_name = Ident::parse(ctx, tokens)?;
Expand All @@ -67,7 +97,7 @@ impl StructDef {
// ^^^^^
let field_ty = Ty::parse(ctx, tokens)?;
span = span.merge_with(field_ty.span);
fields.push((field_name, field_ty));
fields.push((field_name, field_ty, attribute));

// struct Foo { a: Field, b: Field }
// ^ ^
Expand Down
2 changes: 1 addition & 1 deletion src/stdlib/builtins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ fn assert_eq_values<B: Backend>(

// compare each field recursively
let mut offset = 0;
for (_, field_type) in &struct_info.fields {
for (_, field_type, _) in &struct_info.fields {
let field_size = compiler.size_of(field_type);
let mut field_comparisons = assert_eq_values(
compiler,
Expand Down
20 changes: 19 additions & 1 deletion src/stdlib/native/int/lib.no
Original file line number Diff line number Diff line change
Expand Up @@ -291,4 +291,22 @@ fn Uint32.mod(self, rhs: Uint32) -> Uint32 {
fn Uint64.mod(self, rhs: Uint64) -> Uint64 {
let res = self.divmod(rhs);
return res[1];
}
}

// implement to field
fn Uint8.to_field(self) -> Field {
return self.inner;
}

fn Uint16.to_field(self) -> Field {
return self.inner;
}

fn Uint32.to_field(self) -> Field {
return self.inner;
}

fn Uint64.to_field(self) -> Field {
return self.inner;
}

2 changes: 1 addition & 1 deletion src/tests/modules.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use mimoo::liblib;
// test a library's type that links to its own type
struct Inner {
inner: Field,
pub inner: Field,
}
struct Lib {
Expand Down
2 changes: 1 addition & 1 deletion src/tests/stdlib/uints/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ fn main(pub lhs: Field, rhs: Field) -> Field {
let res = lhs_u.{opr}(rhs_u);
return res.inner;
return res.to_field();
}
"#;

Expand Down
38 changes: 30 additions & 8 deletions src/type_checker/checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ use crate::{
imports::FnKind,
parser::{
types::{
is_numeric, FnSig, ForLoopArgument, FunctionDef, ModulePath, Stmt, StmtKind, Symbolic,
Ty, TyKind,
is_numeric, Attribute, AttributeKind, FnSig, ForLoopArgument, FuncOrMethod,
FunctionDef, ModulePath, Stmt, StmtKind, Symbolic, Ty, TyKind,
},
CustomType, Expr, ExprKind, Op2,
},
Expand Down Expand Up @@ -58,7 +58,7 @@ impl<B: Backend> FnInfo<B> {
#[derive(Deserialize, Serialize, Default, Debug, Clone)]
pub struct StructInfo {
pub name: String,
pub fields: Vec<(String, TyKind)>,
pub fields: Vec<(String, TyKind, Option<Attribute>)>,
pub methods: HashMap<String, FunctionDef>,
}

Expand Down Expand Up @@ -119,14 +119,36 @@ impl<B: Backend> TypeChecker<B> {
.expect("this struct is not defined, or you're trying to access a field of a struct defined in a third-party library (TODO: better error)");

// find field type
let res = struct_info
if let Some((_, field_typ, attribute)) = struct_info
.fields
.iter()
.find(|(name, _)| name == &rhs.value)
.map(|(_, typ)| typ.clone());
.find(|(field_name, _, _)| field_name == &rhs.value)
{
// check for the pub attribute
let is_public = matches!(
attribute,
&Some(Attribute {
kind: AttributeKind::Pub,
..
})
);

// check if we're inside a method of the same struct
let in_method = matches!(
typed_fn_env.current_fn_kind(),
FuncOrMethod::Method(method_struct) if method_struct.name == struct_name
);

if let Some(res) = res {
Some(ExprTyInfo::new(lhs_node.var_name, res))
if is_public || in_method {
// allow access
Some(ExprTyInfo::new(lhs_node.var_name, field_typ.clone()))
} else {
// block access
Err(self.error(
ErrorKind::PrivateFieldAccess(struct_name.clone(), rhs.value.clone()),
expr.span,
))?
}
} else {
return Err(self.error(
ErrorKind::UndefinedField(struct_info.name.clone(), rhs.value.clone()),
Expand Down
Loading

0 comments on commit bb87a88

Please sign in to comment.