Multithreading in Rust

The general idea behind multithreading is to accomplish more work in a shorter amount of time. This is done by splitting the work up across multiple threads.

These threads then run concurrently (and possibly even in parallel).

Those threads must be (mostly) independent of one another. Code that relies on data from another thread must be handled with care.

Because of this, writing multithreaded code is frequently regarded as "extremely hard." Many multithreading bugs are difficult to find because they are so subtle.

Fortunately for Rust programmers, safe and efficient concurrent programming is one of the language's main goals. Many of these bugs are impossible to write in Rust: the incorrect code will not compile, and the compiler will display an error message explaining the problem.

This article examines various approaches to writing multithreaded programs by solving a programming exercise.

Creating threads

A thread is created with the std::thread::spawn function. This function accepts a closure.

The closure contains the code the thread will run.

The instant it is created, that thread is detached from the thread that created it. It is entirely independent and can outlive the thread it originated from (unless that creator is the main thread; if the main thread stops, everything stops).

Everything passed to the closure must be valid for the entire duration of the program (i.e., it must have a 'static lifetime), because the new thread may potentially outlive the thread that spawned it. This guarantees the thread remains valid even if the thread that created it disappears.

In practice, this means you want the closure to take ownership of every variable it uses. This is done by placing the move keyword in front of the closure's parameter list.
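
For example, a minimal sketch of capturing a variable with move:

use std::thread;
fn main() {
    let name = String::from("world");
    // `move` transfers ownership of `name` into the closure,
    // so the spawned thread stays valid even if this scope ends first
    let handle = thread::spawn(move || {
        println!("hello, {}", name);
    });
    // wait for the spawned thread to finish (join is explained below)
    handle.join().unwrap();
}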

A parent thread can be made to wait until a thread it spawned has finished.

Calling std::thread::spawn returns a JoinHandle. Calling the join method on that handle blocks the current thread until the spawned thread has finished.

use std::thread;
fn main() {
    let handle = thread::spawn(move || {
        // some work here
    });
    // some work here
    // block until the spawned thread has finished
    handle.join().unwrap();
}

The problem

We must write a frequency function. It takes two parameters: a slice of strings and the number of workers.

The return value is a HashMap. Its keys are the letters that appear in those strings, and each letter's value is how often it occurs.

It must complete this work using worker_count threads.

fn frequency(input: &[&str], worker_count: usize) -> HashMap<char, usize>

Single threaded

This example demonstrates a single-threaded approach to solving the problem.

use std::collections::HashMap;
fn frequency(input: &[&str]) -> HashMap<char, usize> {
    let mut map = HashMap::new();
    for line in input {
        for c in line.chars().filter(|c| c.is_alphabetic()) {
            *map.entry(c.to_ascii_lowercase()).or_default() += 1;
        }
    }
    map
}
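
A quick usage sketch (the inputs are illustrative; note the counting is case-insensitive and ignores non-alphabetic characters):

fn main() {
    let counts = frequency(&["Hello", "World"]);
    assert_eq!(counts[&'l'], 3);
    assert_eq!(counts[&'o'], 2);
}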

Strategy

We'll break the big problem up into a number of smaller ones.

Each one will be solved on its own thread, with code resembling the single-threaded example above.

The results of those smaller problems must then be combined into one final result, which will be the frequency function's return value.

The larger problem is divided up by calling the chunks method on the input parameter.

input.chunks((input.len() / worker_count).max(1));

The result is an iterator that yields roughly worker_count chunks. Each thread will solve the problem for one chunk.
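
A small sketch of the chunking behavior. Note that because of the integer division, the number of chunks can slightly exceed worker_count:

fn main() {
    let input = ["a", "b", "c", "d", "e", "f", "g"];
    let worker_count = 3;
    let sizes: Vec<usize> = input
        .chunks((input.len() / worker_count).max(1))
        .map(|chunk| chunk.len())
        .collect();
    // 7 / 3 == 2, so we get four chunks of sizes [2, 2, 2, 1]
    assert_eq!(sizes, [2, 2, 2, 1]);
}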

Before passing a chunk of data into a thread, we make sure we own it by joining the string slices into one owned String:

chunk.join("")

After that, we are ready to spawn a thread to solve the problem for each chunk.

use std::collections::HashMap;
use std::thread;
pub fn frequency(input: &[&str], worker_count: usize) -> HashMap<char, usize> {
    // divide the large problem into smaller problems
    let chunks = input.chunks((input.len() / worker_count).max(1));
    for chunk in chunks {
        // collect the data for the current chunk into an owned String before sending it to the thread
        let string = chunk.join("");
        thread::spawn(move || {
            // solve the problem for the current chunk (uses the owned `string`)
            let _ = string;
        });
    }
    // combine the solutions into the final HashMap
    todo!()
}

JoinHandle

A JoinHandle can wrap an inner value: whatever the spawned thread's closure returns. When the parent thread calls join on the handle, it gets access to whatever the child thread returned.
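
A minimal sketch:

use std::thread;
fn main() {
    // the closure returns a value; the JoinHandle wraps it
    let handle = thread::spawn(|| 1 + 2);
    // join returns a Result; on success it contains the closure's return value
    let value = handle.join().unwrap();
    assert_eq!(value, 3);
}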

use std::collections::HashMap;
use std::thread;
pub fn frequency(input: &[&str], worker_count: usize) -> HashMap<char, usize> {
    let mut result: HashMap<char, usize> = HashMap::new();
    let chunks = input.chunks((input.len() / worker_count).max(1));
    let mut handles = Vec::new();
    for chunk in chunks {
        let string = chunk.join("");
        // return a HashMap from each thread, the JoinHandle wraps this hashmap
        let handle = thread::spawn(move || {
            let mut map: HashMap<char, usize> = HashMap::new();
            for c in string.chars().filter(|c| c.is_alphabetic()) {
                *map.entry(c.to_ascii_lowercase()).or_default() += 1;
            }
            map
        });
        handles.push(handle);
    }
    // wait for each thread to finish and combine every HashMap into the final result
    for handle in handles {
        let map = handle.join().unwrap();
        for (key, value) in map {
            *result.entry(key).or_default() += value;
        }
    }
    result
}

Channel

Message passing is a popular technique for safe concurrency. Multiple threads communicate by sending each other messages containing data.

Rust's tool for this is the channel.

A channel is comparable to a stream of water: put something in at one end and it comes out at the other end.

In code, a channel has two sides: a transmitter (the Sender) and a receiver (the Receiver). You put items in at the sender and take them out at the receiver.

"Multiple producer, single consumer" refers to an implementation of this in the Rust standard library. Therefore, there can be several senders on the mpsc channel but only one recipient.

This is comparable to many smaller streams flowing together into one river.

The std::sync::mpsc::channel function returns a tuple of a Sender and a Receiver.

The Sender can be cloned, producing multiple senders that can each be moved into a different thread.

As you might have guessed, a Sender's send method sends a value down the channel.

A sent value must be an owned value; it can no longer be used on the thread it was sent from. When the receiver takes that value out, ownership transfers to the receiving thread.

You can think of a channel as a single-ownership construct.

The recv method of the Receiver blocks the current thread until a message is received.

use std::thread;
use std::sync::mpsc;
fn main() {
    let (sender, receiver) = mpsc::channel();
    for i in 0..10 {
        let sender = sender.clone();
        thread::spawn(move || {
            sender.send(i).unwrap();
        });
    }
    for _ in 0..10 {
        // receive each value and wait between each one
        println!("Got: {}", receiver.recv().unwrap());
    }
}

The channel is closed when either all of the senders or the single receiver are dropped.

The receiver can be iterated over. The iterator blocks while waiting for the next value, and when the channel is closed it returns None, ending the iteration.

This poses a minor issue if we send cloned Senders into threads: the original Sender is never dropped, so the channel stays open forever.

In the previous example, where we looped a fixed number of times, this wasn't a problem, but with the iterator technique we would end up waiting forever.

use std::thread;
use std::sync::mpsc;
fn main() {
    let (sender, receiver) = mpsc::channel();
    for i in 0..10 {
        let sender = sender.clone();
        thread::spawn(move || {
            sender.send(i).unwrap();
        });
    }
    // this will wait until all senders are dropped
    // the original sender is never dropped, so this waits forever
    for received in receiver {
        println!("Got: {}", received);
    }
}

The solution is to drop the original sender.

use std::mem;
use std::thread;
use std::sync::mpsc;
fn main() {
    let (sender, receiver) = mpsc::channel();
    for i in 0..10 {
        let sender = sender.clone();
        thread::spawn(move || {
            sender.send(i).unwrap();
        });
    }
    // drop the original sender
    mem::drop(sender);
    for received in receiver {
        println!("Got: {}", received);
    }
}

Alternatively, the entire top part can be wrapped in a scope with curly braces, ensuring the original sender has been dropped by the time we iterate over the receiver, as in the sketch below.
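
A sketch of that variant:

use std::thread;
use std::sync::mpsc;
fn main() {
    let receiver = {
        let (sender, receiver) = mpsc::channel();
        for i in 0..10 {
            let sender = sender.clone();
            thread::spawn(move || {
                sender.send(i).unwrap();
            });
        }
        // the original sender is dropped when this scope ends
        receiver
    };
    // only the cloned senders inside the threads keep the channel open
    for received in receiver {
        println!("Got: {}", received);
    }
}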

Applied

Going back to our frequency problem:

We convert each chunk of data into an owned String and move it into a thread.

This time, each thread receives a clone of the sender as well.

Whenever a thread has finished solving its chunk, it sends its HashMap through the channel.

The main thread then picks up all the messages and merges them into the final result.

use std::collections::HashMap;
use std::mem;
use std::sync::mpsc;
use std::thread;
pub fn frequency(input: &[&str], worker_count: usize) -> HashMap<char, usize> {
    let mut result: HashMap<char, usize> = HashMap::new();
    let chunks = input.chunks((input.len() / worker_count).max(1));
    let (sender, receiver) = mpsc::channel();
    for chunk in chunks {
        let sender = sender.clone();
        let string = chunk.join("");
        thread::spawn(move || {
            // Solve each chunk and send the resulting HashMap down the channel
            let mut map: HashMap<char, usize> = HashMap::new();
            for c in string.chars().filter(|c| c.is_alphabetic()) {
                *map.entry(c.to_ascii_lowercase()).or_default() += 1;
            }
            sender.send(map).unwrap();
        });
    }
    // drop the original sender, else the channel will remain open, causing the receiver to infinitely wait
    mem::drop(sender);
    // combine every received HashMap
    for received in receiver {
        for (key, value) in received {
            *result.entry(key).or_default() += value;
        }
    }
    result
}

Mutex

A Mutex wraps another piece of data. It ensures that only one thread at a time can access the inner data, because every access has to go through the mutex first.

The name is derived from "mutual exclusion".

A Mutex can be used in a single-threaded setting, but using one where you can safely access the data directly would add an unneeded layer of complexity.

Because the mutex will be moved into threads, it is frequently wrapped in an Arc so that it can be owned by several threads simultaneously.

A mutex lets several threads access (and modify) the same piece of data while ensuring that only one thread can access it at any given moment.

You can think of a mutex as a multiple-ownership construct.

A mutex's lock method returns a MutexGuard if it succeeds. This "locks" the mutex, preventing access from any other thread.

While that guard exists, any other thread that tries to access the mutex will block until it can acquire the lock.

That MutexGuard is a smart pointer.

The smart pointer provides access to the inner data. When the MutexGuard goes out of scope, the lock is released, giving another thread a chance to acquire it.

The next example spawns ten threads, each of which increments the value inside the mutex. Although the order in which the threads run is unpredictable, the final count will always be 10.

use std::sync::{Arc, Mutex};
use std::thread;
fn main() {
    let mutex = Arc::new(Mutex::new(0));
    let mut handles = Vec::new();
    for _ in 0..10 {
        let mutex = Arc::clone(&mutex);
        let handle = thread::spawn(move || {
            let mut guard = mutex.lock().unwrap();
            *guard += 1;
        });
        handles.push(handle);
    }
    for handle in handles {
        handle.join().unwrap();
    }
    assert_eq!(*mutex.lock().unwrap(), 10);
}

Applied

Going back to our frequency problem:

We convert each chunk of data into an owned String and move it into a thread.

This time, each thread receives (a clone of the Arc around) the mutex as well.

Once a thread has finished solving its chunk, it merges its resulting HashMap into the HashMap inside the mutex.

After waiting for all threads to complete, the main thread returns the data the mutex wraps.

use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::thread;
pub fn frequency(input: &[&str], worker_count: usize) -> HashMap<char, usize> {
    let result = Arc::new(Mutex::new(HashMap::new()));
    let chunks = input.chunks((input.len() / worker_count).max(1));
    let mut handles: Vec<_> = Vec::new();
    for chunk in chunks {
        let string = chunk.join("");
        let result = Arc::clone(&result);
        let handle = thread::spawn(move || {
            let mut map: HashMap<char, usize> = HashMap::new();
            // create a HashMap for this chunk
            for c in string.chars().filter(|c| c.is_alphabetic()) {
                *map.entry(c.to_ascii_lowercase()).or_default() += 1;
            }
            // add the HashMap of this chunk to the HashMap that is wrapped by the Mutex
            let mut result = result.lock().unwrap();
            for (key, value) in map {
                *result.entry(key).or_default() += value;
            }
        });
        handles.push(handle);
    }
    // wait for each thread to finish
    for handle in handles {
        handle.join().unwrap();
    }
    // get the HashMap from the Arc<Mutex<HashMap>>
    Arc::try_unwrap(result).unwrap().into_inner().unwrap()
}

Bonus: gotchas

It is important to make sure that any concurrent code you write actually benefits from being concurrent. As long as the code is correct, the Rust compiler will happily let you compile code that runs more slowly than its sequential counterpart.

Where you lock mutexes and where you join threads matters. Every time you block a thread and make it wait before continuing, ask yourself whether there is other work it could do first.

Sometimes this means structuring code in a way that at first glance appears less efficient, but is quicker because it spends less time waiting.

Take the mutex example above. Each thread's code consists of two for loops and reads roughly like this:

for loop (count the characters in this chunk)
lock (acquire the mutex)
for loop (merge this chunk's counts into the shared HashMap)

The same thing could be accomplished with just one for loop by locking the mutex before that loop.

But in that form, the computation essentially becomes sequential again: each thread locks the mutex and blocks access for every other thread.

Because all the work happens after the lock has been obtained, each thread must wait for its turn.

By splitting each thread's code into two sections, the work specific to a thread is completed without holding up the other threads, as the sketch below shows.
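
For comparison, here is a sketch of the single-loop, lock-first variant (the function name frequency_lock_first is illustrative; this is the slow version, not the solution above):

use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::thread;
pub fn frequency_lock_first(input: &[&str], worker_count: usize) -> HashMap<char, usize> {
    let result = Arc::new(Mutex::new(HashMap::new()));
    let chunks = input.chunks((input.len() / worker_count).max(1));
    let mut handles = Vec::new();
    for chunk in chunks {
        let string = chunk.join("");
        let result = Arc::clone(&result);
        handles.push(thread::spawn(move || {
            // locking before the counting loop serializes the real work:
            // only one thread can count characters at a time
            let mut result = result.lock().unwrap();
            for c in string.chars().filter(|c| c.is_alphabetic()) {
                *result.entry(c.to_ascii_lowercase()).or_default() += 1;
            }
        }));
    }
    for handle in handles {
        handle.join().unwrap();
    }
    Arc::try_unwrap(result).unwrap().into_inner().unwrap()
}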

I'm a big fan of iterators.

An iterator chain processes one item at a time.

This means each item travels through the entire chain before the next item starts:

// the entire chain executes for 1 first, then for 2, then for 3
let sum: i32 = [1, 2, 3].iter().filter(|&&n| n % 2 != 0).map(|n| n * 2).sum();

In a previous version of my JoinHandle code, everything was one big iterator chain.

Inside that chain I called handle.join(). That meant every other thread couldn’t even be spawned before the previous one finished running.
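
A sketch of what that pitfall looks like (reconstructed for illustration, not the original code):

use std::collections::HashMap;
use std::thread;
pub fn frequency(input: &[&str], worker_count: usize) -> HashMap<char, usize> {
    input
        .chunks((input.len() / worker_count).max(1))
        .map(|chunk| {
            let string = chunk.join("");
            thread::spawn(move || {
                let mut map: HashMap<char, usize> = HashMap::new();
                for c in string.chars().filter(|c| c.is_alphabetic()) {
                    *map.entry(c.to_ascii_lowercase()).or_default() += 1;
                }
                map
            })
        })
        // because the chain is lazy, this join runs before the next spawn:
        // each thread finishes before the following one even starts
        .map(|handle| handle.join().unwrap())
        .fold(HashMap::new(), |mut result, map| {
            for (key, value) in map {
                *result.entry(key).or_default() += value;
            }
            result
        })
}

Collecting the handles into a Vec first, as in the JoinHandle solution, forces every spawn to happen before the first join.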

Final code

JoinHandle

use std::collections::HashMap;
use std::thread;
pub fn frequency(input: &[&str], worker_count: usize) -> HashMap<char, usize> {
    let mut result: HashMap<char, usize> = HashMap::new();
    let chunks = input.chunks((input.len() / worker_count).max(1));
    let mut handles = Vec::new();
    for chunk in chunks {
        let string = chunk.join("");
        // return a HashMap from each thread, the JoinHandle wraps this hashmap
        let handle = thread::spawn(move || {
            let mut map: HashMap<char, usize> = HashMap::new();
            for c in string.chars().filter(|c| c.is_alphabetic()) {
                *map.entry(c.to_ascii_lowercase()).or_default() += 1;
            }
            map
        });
        handles.push(handle);
    }
    // wait for each thread to finish and combine every HashMap into the final result
    for handle in handles {
        let map = handle.join().unwrap();
        for (key, value) in map {
            *result.entry(key).or_default() += value;
        }
    }
    result
}

Channel

use std::collections::HashMap;
use std::mem;
use std::sync::mpsc;
use std::thread;

pub fn frequency(input: &[&str], worker_count: usize) -> HashMap<char, usize> {
    let mut result: HashMap<char, usize> = HashMap::new();
    let chunks = input.chunks((input.len() / worker_count).max(1));
    let (sender, receiver) = mpsc::channel();
    for chunk in chunks {
        let sender = sender.clone();
        let string = chunk.join("");
        thread::spawn(move || {
            // Solve each chunk and send the resulting HashMap down the channel
            let mut map: HashMap<char, usize> = HashMap::new();
            for c in string.chars().filter(|c| c.is_alphabetic()) {
                *map.entry(c.to_ascii_lowercase()).or_default() += 1;
            }
            sender.send(map).unwrap();
        });
    }
    // drop the original sender, else the channel will remain open, causing the receiver to infinitely wait
    mem::drop(sender);
    // combine every received HashMap
    for received in receiver {
        for (key, value) in received {
            *result.entry(key).or_default() += value;
        }
    }

    result
}

Mutex

use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::thread;

pub fn frequency(input: &[&str], worker_count: usize) -> HashMap<char, usize> {
    let result = Arc::new(Mutex::new(HashMap::new()));
    let chunks = input.chunks((input.len() / worker_count).max(1));
    let mut handles: Vec<_> = Vec::new();
    for chunk in chunks {
        let string = chunk.join("");
        let result = Arc::clone(&result);
        let handle = thread::spawn(move || {
            let mut map: HashMap<char, usize> = HashMap::new();
            // create a HashMap for this chunk
            for c in string.chars().filter(|c| c.is_alphabetic()) {
                *map.entry(c.to_ascii_lowercase()).or_default() += 1;
            }
            // add the HashMap of this chunk to the HashMap that is wrapped by the Mutex
            let mut result = result.lock().unwrap();
            for (key, value) in map {
                *result.entry(key).or_default() += value;
            }
        });
        handles.push(handle);
    }
    // wait for each thread to finish
    for handle in handles {
        handle.join().unwrap();
    }
    // get the HashMap from the Arc<Mutex<HashMap>>
    Arc::try_unwrap(result).unwrap().into_inner().unwrap()
}
