Misuse Resistant Code
"... a programming language designer should be responsible for the mistakes that are made by the programmers"
— Tony Hoare
# Your crypto API
def encrypt(msg: bytes, key: bytes, nonce: bytes) -> bytes:
    """Encrypt ``msg`` under ``key`` with ``nonce`` (stub).

    The nonce is plain ``bytes``, so nothing in this signature stops a
    caller from passing the same value twice.
    """
    ...  # impl
# Consumer code
def main() -> None:
    """Encrypt two messages, accidentally sharing one nonce between them."""
    shared_nonce = make_nonce()  # cryptographic nonce
    secret_key = get_key()  # load up the key
    encrypt(b"message A", secret_key, shared_nonce)  # encrypt, hooray
    encrypt(b"message B", secret_key, shared_nonce)  # nonce reuse :(
class Nonce:
    """A single-use wrapper around a cryptographic nonce.

    The raw bytes can be read exactly once through ``get_nonce``; any later
    call raises ``ValueError`` so that nonce reuse fails loudly at runtime.
    """

    def __init__(self) -> None:
        # Name-mangled to discourage direct access from outside the class;
        # reset to None once the nonce has been handed out.
        self.__nonce: Optional[bytes] = make_nonce()

    def get_nonce(self) -> bytes:
        """Return the raw nonce bytes, consuming this ``Nonce``.

        Raises:
            ValueError: if the nonce has already been retrieved.
        """
        if self.__nonce is None:
            raise ValueError("Attempted to reuse nonce! This is a bug!")
        nonce = self.__nonce
        # Clear the stored value so that a second call is detected above.
        self.__nonce = None
        return nonce
def encrypt(msg: bytes, key: bytes, nonce: "Nonce") -> bytes:
    """Encrypt ``msg`` under ``key`` using a single-use ``Nonce`` (stub).

    The raw nonce bytes are extracted here, inside the crypto library, so
    consumer code only ever handles the ``Nonce`` wrapper.

    Raises:
        NotImplementedError: always; encryption itself is not written yet.
    """
    # in our crypto library we can ensure that
    # we're very careful with the raw value
    raw_nonce = nonce.get_nonce()
    raise NotImplementedError
def main() -> None:
    """Encrypt two messages with one single-use ``Nonce``.

    The second use of the nonce is meant to be rejected at runtime.
    """
    one_time_nonce = Nonce()  # cryptographic nonce
    secret_key = get_key()  # load up the key
    encrypt(b"message A", secret_key, one_time_nonce)  # encrypt, hooray
    # Runtime error! Better than a crypto failure
    encrypt(b"message B", secret_key, one_time_nonce)
Affine Types (a little Rust)
/// A single-use cryptographic nonce.
///
/// `Nonce` deliberately implements neither `Clone` nor `Copy`. The raw bytes
/// are only reachable through `unchecked_get`, which takes `self` by value,
/// so each `Nonce` can be consumed at most once (an affine type).
pub struct Nonce {
    inner_nonce: Vec<u8>,
}

impl Nonce {
    /// Creates a fresh nonce.
    ///
    /// # Panics
    ///
    /// Always panics for now: nonce generation is not implemented yet.
    pub fn new() -> Self {
        unimplemented!()
    }

    // Private method - can't be accessed outside of our crypto module!
    // Takes `self` by value, so the raw bytes can be extracted only once.
    fn unchecked_get(self) -> Vec<u8> {
        self.inner_nonce
    }
}
# Note that `nonce` is not passed in by value, it is *moved* into
# this function - it can't be accessed afterwards by the caller
pub fn encrypt(msg: &[u8], key: &[u8], nonce: Nonce) -> Vec<u8> {
let nonce = nonce.unchecked_get(); # Be careful with the raw value!
unimplemented!()
}
fn main() {
    let nonce = Nonce::new();
    let key = &b"very safe key!"[..];
    // Pass byte-string literals directly: `b"..."` is already a `&[u8; N]`,
    // which coerces to `&[u8]`; `key` is already a `&[u8]`.
    encrypt(b"Message A", key, nonce); // Encrypts, yay
    // Compile time error! `nonce` was moved into the first call (E0382).
    encrypt(b"Message B", key, nonce);
}
Refinement Types
# Implicit constraint: the user has a valid session
def disable_mfa(username: str) -> None:
    """Disable MFA for ``username`` (stub).

    The signature only asks for a plain string, so the "caller has already
    checked the session" requirement exists only in the caller's head.
    """
    ...
def handle_request(request: Request) -> Response:
    """Route a request; the MFA path skips the session check entirely."""
    if request.path != "/disable_mfa":
        return None
    # We forgot to check if the user is actually logged in
    disable_mfa(request.username)
Refinement Types
class User:
    """A user identified from an incoming request; session not yet verified."""

    # Forward references are quoted: `User` is not bound yet while its own
    # class body runs, and `LoggedInUser`/`Request` may be defined later.
    @staticmethod
    def from_request(request: "Request") -> "User":
        ...

    def as_logged_in_user(self) -> Optional["LoggedInUser"]:
        """Refine this user into a ``LoggedInUser``.

        Delegates to ``LoggedInUser.try_from_user``; a ``None`` result means
        the refinement failed.
        """
        return LoggedInUser.try_from_user(self)
class LoggedInUser:
    """A ``User`` refined by the check that it has a valid session.

    Intended to be required by functions that need an authenticated caller,
    so the session check is encoded in their signature.
    """

    @staticmethod
    def try_from_user(user: "User") -> Optional["LoggedInUser"]:
        """Refine ``user`` into a ``LoggedInUser``.

        Returns ``None`` when ``user.is_logged_in()`` is false. Construction
        of the refined object is not written in this snippet.
        """
        # A staticmethod has no `self`: the check must be made on `user`.
        if not user.is_logged_in():
            return None
        ...  # construct
# Explicit constraint: the user has a valid session
def disable_mfa(user: LoggedInUser) -> None:
    """Disable MFA for ``user`` (stub).

    Requiring a ``LoggedInUser`` instead of a username moves the session
    requirement into the signature: callers must refine a ``User`` first.
    """
    ...
def handle_request(request: Request) -> Response:
    """Route a request, requiring a logged-in user before disabling MFA.

    ``disable_mfa`` takes a ``LoggedInUser``, and the only way this function
    obtains one is ``as_logged_in_user()``. The session check therefore
    can't be skipped: when refinement yields nothing, the request is refused.
    """
    if request.path == "/disable_mfa":
        user = User.from_request(request)
        # The signature of disable_mfa forces this check to happen here.
        if logged_in_user := user.as_logged_in_user():
            disable_mfa(logged_in_user)
        else:
            # NOTE(review): returns a bare status code despite the `Response`
            # annotation - presumably a slide simplification; confirm.
            return 403
Session Types: Affine + Refinement
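                   +----------------------+
                   |connection established|
                   +----------------------+
                              ||
                              \/
            +--------------------------------------+
            |          server greeting             |
            +--------------------------------------+
                      || (1)       || (2)        || (3)
                      \/           ||            ||
            +-----------------+    ||            ||
            |Not Authenticated|    ||            ||
            +-----------------+    ||            ||
             || (7)   || (4)       ||            ||
             ||       \/           \/            ||
             ||     +----------------+           ||
             ||     | Authenticated  |<=++       ||
             ||     +----------------+  ||       ||
             ||       || (7)   || (5)   || (6)   ||
             ||       ||       \/       ||       ||
             ||       ||    +--------+  ||       ||
             ||       ||    |Selected|==++       ||
             ||       ||    +--------+           ||
             ||       ||       || (7)            ||
             \/       \/       \/                \/
            +--------------------------------------+
            |               Logout                 |
            +--------------------------------------+
                              ||
                              \/
               +-------------------------------+
               |both sides close the connection|
               +-------------------------------+
(1) connection without pre-authentication (OK greeting)
(2) pre-authenticated connection (PREAUTH greeting)
(3) rejected connection (BYE greeting)
(4) successful LOGIN or AUTHENTICATE command
(5) successful SELECT or EXAMINE command
(6) CLOSE command, or failed SELECT or EXAMINE command
(7) LOGOUT command, server shutdown, or connection closed
(IMAP4rev1 state and flow diagram, RFC 3501 section 3.4)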
Session Types: Affine + Refinement
struct SomeState {}
struct NewState {}

impl SomeState {
    /// Attempts to move from `SomeState` to `NewState`.
    ///
    /// Takes `self` by value, so the old state is consumed and can't be used
    /// after the call. The `Result` carries both outcomes: the new state on
    /// success, or the original state handed back on failure so the caller
    /// can try again or take a different transition.
    fn transition(self) -> Result<NewState, SomeState> {
        unimplemented!()
    }
}
Session Types: Affine + Refinement
struct IMAPClient {}
impl IMAPClient {
fn new() -> Self {unimplemented!()}
fn connect(&mut self) -> Result<(), Error> {unimplemented!()}
fn login(&mut self) -> Result<(), Error> {unimplemented!()}
fn select(&mut self) -> Result<(), Error> {unimplemented!()}
fn authenticate(&mut self) -> Result<(), Error> {unimplemented!()}
fn logout(&mut self) -> Result<(), Error> {unimplemented!()}
}
fn main() {
    let mut imap = IMAPClient::new();
    // Nothing in the types prevents this call:
    // this makes no sense - we aren't authenticated!
    imap.select();
}
Session Types: Affine + Refinement
/// The underlying connection, carried from one session state to the next.
struct IMAPClient {}

// The state types are `pub` so that code outside this module can name them
// (e.g. `use imap::Initial`). Their `connection` fields stay private, so
// outside code can't build a state with a struct literal; it has to go
// through the methods defined in this module.

/// Session state before a connection has been established.
pub struct Initial {
    connection: IMAPClient,
}

/// Connected, but not yet authenticated.
pub struct UnAuthenticated {
    connection: IMAPClient,
}

/// Authenticated, with no mailbox selected.
pub struct Authenticated {
    connection: IMAPClient,
}

/// Authenticated, with a mailbox selected.
pub struct Selected {
    connection: IMAPClient,
}

/// The session has been logged out.
pub struct LoggedOut {
    connection: IMAPClient,
}
Session Types: Affine + Refinement
impl Initial {
fn connect(self) -> Result<UnAuthenticated, (Initial, Error)> {
unimplemented!()
}
}
impl UnAuthenticated {
fn login(self) -> Result<Authenticated, (UnAuthenticated, Error)> {
unimplemented!()
}
fn logout(self) -> Result<Logout, (UnAuthenticated, Error)> {
unimplemented!()
}
}
Session Types: Affine + Refinement
use imap::Initial;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Every transition consumes the previous state. A failed transition
    // returns that state alongside the error; this demo simply gives up.
    let unauthenticated = Initial::init()
        .connect()
        .unwrap_or_else(|(_client, error)| panic!("failed to connect with {:#?}", error));
    let _authenticated = unauthenticated
        .login()
        .unwrap_or_else(|(_client, error)| panic!("failed to login with {:#?}", error));
    // etc
    Ok(())
}
Dependent Types
class Matrix:
    """A matrix of zeros whose dimensions are only known at runtime."""

    def __init__(self, x, y) -> None:
        self.x = x  # number of rows
        self.y = y  # number of columns
        self.values = [0] * x * y

    # Annotations are quoted: `Matrix` is not bound yet while its own class
    # body is executing.
    def multiply(self, other: "Matrix") -> "Matrix":
        """Multiply this (x by y) matrix by ``other`` (other.x by other.y).

        Raises:
            ValueError: if ``self.y != other.x``, i.e. the inner dimensions
                disagree. Without dimensions in the type, this can only be
                detected when the call runs.
        """
        if self.y != other.x:
            raise ValueError("Matrix multiplication invariant violated")
        ...
def main():
    matrix_a = Matrix(2, 4)
    matrix_b = Matrix(3, 7)
    # Inner dimensions differ (4 vs 3), so this is only caught when it runs.
    matrix_c = matrix_a.multiply(matrix_b)  # runtime error
Dependent Types
class Matrix[X: int, Y: int]:
    """Illustrative syntax: a matrix whose X rows and Y columns are in its type.

    With the dimensions lifted into the type, a checker could compare them
    before the program runs (dependent types).
    """
    def __init__(self, x, y) -> None:
        self.x = x
        self.y = y
        # Flat storage of x * y zero entries.
        self.values = [0] * x * y
    # `other` must have Y rows (this matrix's column count); the result has
    # this matrix's X rows and other's N columns.
    def multiply[N: int](self, other: Matrix[Y, N]) -> Matrix[X, N]:
        ...
def main():
    matrix_a = Matrix(2, 4)
    matrix_b = Matrix(4, 7)
    # Compile time checks that the matrices can be multiplied:
    # (2 x 4) * (4 x 7) -> (2 x 7)
    matrix_d: Matrix[2, 7] = matrix_a.multiply(matrix_b)
Dependent Types
class Matrix[X: int, Y: int]:
    """Illustrative syntax: a matrix with X rows and Y columns in its type.

    Because the dimensions are type-level integers, the return type below
    can do arithmetic on them.
    """
    def __init__(self, x, y) -> None:
        self.x = x
        self.y = y
        # Flat storage of x * y zero entries.
        self.values = [0] * x * y
    # Stacking an N-row matrix with the same column count Y beneath this one
    # yields X + N rows; a mismatched Y is meant to be rejected before running.
    def append_rows[N: int](self, other: Matrix[N, Y]) -> Matrix[X + N, Y]:
        ...
def main():
    top = Matrix(2, 5)
    bottom = Matrix(4, 5)
    # 2 rows + 4 rows, both 5 columns wide -> a 6 x 5 result
    stacked: Matrix[6, 5] = top.append_rows(bottom)
Dependent Types + Type Narrowing
class Matrix[X: int, Y: int]:
    """Illustrative syntax: a matrix with X rows and Y columns in its type."""
    def __init__(self, x, y) -> None:
        self.x = x
        self.y = y
        # Flat storage of x * y zero entries.
        self.values = [0] * x * y
    # Requires `other` to share this matrix's column count Y; result has X + N rows.
    def append_rows[N: int](self, other: Matrix[N, Y]) -> Matrix[X + N, Y]:
        ...
    # Narrowing helper for a matrix of unknown shape. When the column counts
    # agree, it returns `other` unchanged but typed as Matrix[A, Y]; otherwise
    # it returns None.
    def check_for_append[A: int, B: int](self, other: Matrix[A, B]) -> Optional[Matrix[A, Y]]:
        if self.y == other.y:
            return other
        else:
            return None
def main():
    matrix_a = Matrix(2, 5)
    matrix_b: Matrix[_, _] = load_matrix_from_disk()
    if matrix_c := matrix_a.check_for_append(matrix_b):
        matrix_c: Matrix[_, 5]  # We know this type checks now
        # Append the narrowed `matrix_c`, not the still-unknown `matrix_b`.
        # The result has matrix_a's 2 rows plus however many matrix_c has.
        matrix_d: Matrix[2 + _, 5] = matrix_a.append_rows(matrix_c)
Dependent Types + Flow Typing
class Matrix[X: int, Y: int]:
    """Illustrative syntax: a matrix with X rows and Y columns in its type."""
    def __init__(self, x, y) -> None:
        self.x = x
        self.y = y
        # Flat storage of x * y zero entries.
        self.values = [0] * x * y
    # Requires `other` to share this matrix's column count Y; result has X + N rows.
    def append_rows[N: int](self, other: Matrix[N, Y]) -> Matrix[X + N, Y]:
        ...
def main():
    matrix_a = Matrix(2, 5)
    matrix_b: Matrix[_, _] = load_matrix_from_disk()
    # Flow Typing: inside this branch the checker learns matrix_b has 5 columns.
    if matrix_b.y == matrix_a.y:
        # matrix_a's 2 rows plus however many rows matrix_b has.
        matrix_d: Matrix[2 + _, 5] = matrix_a.append_rows(matrix_b)
Questions
Misuse Resistant Code
By Colin
Misuse Resistant Code
- 317