Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Multiprocess execution #82

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
11 changes: 9 additions & 2 deletions commpy/examples/wifi80211_conv_encode_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# License: BSD 3-Clause

import math
import time
from datetime import timedelta

import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -22,8 +24,13 @@
SNRs2 = np.arange(0, 6) + 10 * math.log10(w2.get_modem().num_bits_symbol)
SNRs3 = np.arange(0, 6) + 10 * math.log10(w3.get_modem().num_bits_symbol)

BERs_mcs2 = w2.link_performance(channels, SNRs2, 10, 10, 600, stop_on_surpass_error=False)
BERs_mcs3 = w3.link_performance(channels, SNRs3, 10, 10, 600, stop_on_surpass_error=False)

start = time.time()
BERs_mcs2 = w2.link_performance(channels, SNRs2, 10, 10, 600, stop_on_surpass_error=False)[0]
BERs_mcs3 = w3.link_performance(channels, SNRs3, 10, 10, 600, stop_on_surpass_error=False)[0]
print(BERs_mcs2)
print(BERs_mcs3)
print(str(timedelta(seconds=(time.time() - start))))

# Test
plt.semilogy(SNRs2, BERs_mcs2, 'o-', SNRs3, BERs_mcs3, 'o-')
Expand Down
38 changes: 38 additions & 0 deletions commpy/examples/wifi80211_conv_encode_decode_multiprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Authors: CommPy contributors
# License: BSD 3-Clause

import math
import time
from datetime import timedelta

import matplotlib.pyplot as plt
import numpy as np

import commpy.channels as chan
# ==================================================================================================
# Complete example using Commpy Wifi 802.11 physical parameters
# ==================================================================================================
from commpy.multiprocess_links import Wifi80211

# AWGN channel
channel = chan.SISOFlatChannel(None, (1 + 0j, 0j))

w2 = Wifi80211(mcs=2)
w3 = Wifi80211(mcs=3)

# SNR range to test
SNRs2 = np.arange(0, 6) + 10 * math.log10(w2.get_modem().num_bits_symbol)
SNRs3 = np.arange(0, 6) + 10 * math.log10(w3.get_modem().num_bits_symbol)


start = time.time()
BERs = w2.link_performance_mp_mcs([2, 3], [SNRs2, SNRs3], channel, 10, 10, 600, stop_on_surpass_error=False)
print(BERs)
print(str(timedelta(seconds=(time.time() - start))))
# Test
plt.semilogy(SNRs2, BERs[2][0], 'o-', SNRs3, BERs[3][0], 'o-')
plt.grid()
plt.xlabel('Signal to Noise Ration (dB)')
plt.ylabel('Bit Error Rate')
plt.legend(('MCS 2', 'MCS 3'))
plt.show()
213 changes: 213 additions & 0 deletions commpy/multiprocess_links.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
# Authors: CommPy contributors
# License: BSD 3-Clause

"""
============================================
Multiprocess Links (:mod:`commpy.links`)
============================================

.. autosummary::
:toctree: generated/

LinkModel -- Multiprocess Link model object.
Wifi80211 -- Multiprocess class to simulate the transmissions and receiving parameters of physical layer 802.11

"""
from __future__ import division # Python 2 compatibility

from itertools import product, cycle
from multiprocessing import Pool
from typing import Iterable, List, Optional

import numpy as np

from commpy.channels import _FlatChannel
from commpy.links import LinkModel as SPLinkModel
from commpy.wifi80211 import Wifi80211 as SPWifi80211

__all__ = ['LinkModel', 'Wifi80211']


class LinkModel(SPLinkModel):

def __init__(self, modulate, channel, receive, num_bits_symbol, constellation, Es=1., decoder=None, rate=1.,
number_of_process: int = -1):
self.params_builder = _RunParamsBuilder(modulate, channel, receive, num_bits_symbol, constellation, Es, decoder,
rate)
self.full_simulation_results = []
self.number_of_process = number_of_process

def link_performance_full_metrics(self, SNRs: Iterable, tx_max, err_min, send_chunk=None, code_rate: float = 1.,
number_chunks_per_send=1, stop_on_surpass_error=True):
pool = Pool(self.number_of_process if self.number_of_process > 0 else None)
results = pool.map(_run_link_performance_full_metrics,
[self.params_builder.build_to_run([SNR],
tx_max, err_min, send_chunk,
code_rate, number_chunks_per_send,
stop_on_surpass_error)
for SNR in SNRs])
tmp_res = {}
for SNR, res in results:
tmp_res[SNR] = res
tmp_res_keys = sorted(tmp_res.keys())
self.full_simulation_results = [[], [], [], []]
for SNR in tmp_res_keys:
BERs, BEs, CEs, NCs = tmp_res[SNR]
self.full_simulation_results[0].append(BERs)
self.full_simulation_results[1].append(BEs)
self.full_simulation_results[2].append(CEs)
self.full_simulation_results[3].append(NCs)

return self.full_simulation_results

def link_performance(self, SNRs, send_max, err_min, send_chunk=None, code_rate=1):
pool = Pool(self.number_of_process if self.number_of_process > 0 else None)
results = pool.map(_run_link_performance,
[self.params_builder.build_to_run([SNR],
send_max, err_min, send_chunk,
code_rate)
for SNR in SNRs])
tmp_res = {}
for SNR, BERs in results:
tmp_res[SNR] = BERs
tmp_res_keys = sorted(tmp_res.keys())
self.full_simulation_results = []
for SNR in tmp_res_keys:
self.full_simulation_results.extend(tmp_res[SNR])
return self.full_simulation_results


class _RunParamsBuilder:
def __init__(self, modulate, channel, receive, num_bits_symbol, constellation, Es, decoder, rate):
self.modulate = modulate
self.channel = channel
self.receive = receive
self.num_bits_symbol = num_bits_symbol
self.constellation = constellation
self.Es = Es
self.rate = rate
self.decoder = decoder

def build_to_run(self, SNR, tx_max, err_min, send_chunk, code_rate,
number_chunks_per_send=1, stop_on_surpass_error=True):
return _RunParams(self.modulate,
self.channel,
self.receive,
self.num_bits_symbol,
self.constellation,
self.Es,
self.decoder,
self.rate,
SNR, tx_max, err_min, send_chunk, code_rate,
number_chunks_per_send, stop_on_surpass_error
)


class _RunParams:
def __init__(self, modulate, channel, receive, num_bits_symbol, constellation, Es, decoder, rate,
SNRs, tx_max, err_min, send_chunk, code_rate,
number_chunks_per_send, stop_on_surpass_error
):
self.modulate = modulate
self.channel = channel
self.receive = receive
self.num_bits_symbol = num_bits_symbol
self.constellation = constellation
self.Es = Es
self.rate = rate
self.decoder = decoder
self.SNRs = SNRs
self.tx_max = tx_max
self.err_min = err_min
self.send_chunk = send_chunk
self.code_rate = code_rate
self.number_chunks_per_send = number_chunks_per_send
self.stop_on_surpass_error = stop_on_surpass_error


def _run_link_performance_full_metrics(run_params: _RunParams):
link_model = SPLinkModel(run_params.modulate, run_params.channel, run_params.receive, run_params.num_bits_symbol,
run_params.constellation, run_params.Es, run_params.decoder, run_params.rate)
return run_params.SNRs[0], [x[0] for x in
link_model.link_performance_full_metrics(run_params.SNRs, run_params.tx_max,
run_params.err_min,
run_params.send_chunk, run_params.code_rate,
run_params.number_chunks_per_send,
run_params.stop_on_surpass_error)]


def _run_link_performance(run_params: _RunParams):
link_model = SPLinkModel(run_params.modulate, run_params.channel, run_params.receive, run_params.num_bits_symbol,
run_params.constellation, run_params.Es, run_params.decoder, run_params.rate)
return run_params.SNRs[0], [x[0] if isinstance(x, np.ndarray) else x for x in
link_model.link_performance(run_params.SNRs, run_params.tx_max,
run_params.err_min,
run_params.send_chunk, run_params.code_rate)]


class Wifi80211(SPWifi80211):
def __init__(self, mcs: int, number_of_processes=-1):
self.mcs = mcs
self.number_of_processes = number_of_processes

def link_performance(self, channel: _FlatChannel, SNRs: Iterable, tx_max, err_min, send_chunk=None,
frame_aggregation=1, receiver=None, stop_on_surpass_error=True):
return self.link_performance_mp_mcs([self.mcs], [SNRs], channel, tx_max, err_min, send_chunk, frame_aggregation,
[receiver], stop_on_surpass_error)[self.mcs]

def link_performance_mp_mcs(self, mcss: List[int], SNRss: Iterable[Iterable],
channel: _FlatChannel, tx_max, err_min, send_chunk=None,
frame_aggregation=1,
receivers: Optional[Iterable] = None,
stop_on_surpass_error=True):
"""
Explicit multiprocess of multiple MCSs link performance call

Parameters
----------
mcss : list of MCSs to run
SNRss : SNRs to test
channel : Channel to test the MCSs at each SNR at
tx_max : maximum number of transmissions to test
err_min : minimum error be
send_chunk : amount of bits to send at each frame
frame_aggregation : number of frames to send at each transmission
receivers : function to handle receiving
stop_on_surpass_error : flag to stop when err_min was surpassed

Returns
-------

"""
pool = Pool(self.number_of_processes if self.number_of_processes > 0 else None)

if not receivers:
receivers = [None] * len(mcss)

results = pool.map(_run_wifi80211_link_performance,
[[[SNR], mcs, channel, tx_max, err_min, send_chunk, frame_aggregation,
receiver, stop_on_surpass_error]
for _SNRs, mcs, receiver in zip(SNRss, mcss, receivers)
for SNR in _SNRs])
tmp_res = {}
for SNR, mcs, res in results:
tmp_res.setdefault(mcs, {})[SNR] = res
tmp_res_keys = sorted(tmp_res.keys())
full_simulation_results = {}
for mcs in tmp_res_keys:
full_simulation_results[mcs] = [[], [], [], []]
for snr in sorted(tmp_res[mcs].keys()):
BERs, BEs, CEs, NCs = tmp_res[mcs][snr]
full_simulation_results[mcs][0].append(BERs[0])
full_simulation_results[mcs][1].append(BEs)
full_simulation_results[mcs][2].append(CEs)
full_simulation_results[mcs][3].append(NCs)
return full_simulation_results


def _run_wifi80211_link_performance(args: List):
SNRs, mcs, channel, tx_max, err_min, send_chunk, frame_aggregation, receiver, stop_on_surpass_error = args
sp_wifi80211 = SPWifi80211(mcs)
res = sp_wifi80211.link_performance(channel, SNRs, tx_max, err_min, send_chunk, frame_aggregation,
receiver, stop_on_surpass_error)
return SNRs[0], mcs, res
101 changes: 101 additions & 0 deletions commpy/tests/test_multiprocess_links.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Authors: CommPy contributors
# License: BSD 3-Clause

from __future__ import division # Python 2 compatibility

from numpy import arange, sqrt, log10
from numpy.random import seed
from numpy.testing import run_module_suite, assert_allclose, dec
from scipy.special import erfc

from commpy.channels import MIMOFlatChannel, SISOFlatChannel
from commpy.modulation import QAMModem, kbest
from commpy.multiprocess_links import LinkModel, Wifi80211

# from commpy.tests.test_multiprocess_links_support import QPSK, receiver

QPSK = QAMModem(4)


def receiver(y, h, constellation, noise_var):
return QPSK.demodulate(y, 'hard')


QAM16 = QAMModem(16)


def receiver16(y, h, constellation, noise_var):
return QAM16.demodulate(kbest(y, h, constellation, 16), 'hard')


@dec.slow
def test_link_performance():
# Set seed
seed(17121996)

# Apply link_performance to SISO QPSK and AWGN channel

model = LinkModel(QPSK.modulate, SISOFlatChannel(fading_param=(1 + 0j, 0)), receiver,
QPSK.num_bits_symbol, QPSK.constellation, QPSK.Es)

BERs = model.link_performance(range(0, 9, 2), 600e4, 600)
desired = erfc(sqrt(10 ** (arange(0, 9, 2) / 10) / 2)) / 2
assert_allclose(BERs, desired, rtol=0.25,
err_msg='Wrong performance for SISO QPSK and AWGN channel')
full_metrics = model.link_performance_full_metrics(range(0, 9, 2), 1000, 600)
assert_allclose(full_metrics[0], desired, rtol=0.25,
err_msg='Wrong performance for SISO QPSK and AWGN channel')

# Apply link_performance to MIMO 16QAM and 4x4 Rayleigh channel
RayleighChannel = MIMOFlatChannel(4, 4)
RayleighChannel.uncorr_rayleigh_fading(complex)

model = LinkModel(QAM16.modulate, RayleighChannel, receiver16,
QAM16.num_bits_symbol, QAM16.constellation, QAM16.Es)
SNRs = arange(0, 21, 5) + 10 * log10(QAM16.num_bits_symbol)

BERs = model.link_performance(SNRs, 600e4, 600)
desired = (2e-1, 1e-1, 3e-2, 2e-3, 4e-5) # From reference
assert_allclose(BERs, desired, rtol=1.25,
err_msg='Wrong performance for MIMO 16QAM and 4x4 Rayleigh channel')
full_metrics = model.link_performance_full_metrics(SNRs, 1000, 600)
assert_allclose(full_metrics[0], desired, rtol=1.25,
err_msg='Wrong performance for MIMO 16QAM and 4x4 Rayleigh channel')


@dec.slow
def test_wifi80211_siso_channel():
seed(17121996)
wifi80211 = Wifi80211(1, number_of_processes=1)
BERs = wifi80211.link_performance(SISOFlatChannel(fading_param=(1 + 0j, 0)), range(0, 9, 2), 10 ** 4, 600)[0]
desired = (0.548, 0.508, 0.59, 0.81, 0.18) # From previous tests
# for i, val in enumerate(desired):
# print((BERs[i] - val) / val)
assert_allclose(BERs, desired, rtol=0.3,
err_msg='Wrong performance for SISO QPSK and AWGN channel')


wifi80211 = Wifi80211(3)
modem = wifi80211.get_modem()


def receiver_mimo_wifi3(y, h, constellation, noise_var):
return modem.demodulate(kbest(y, h, constellation, 16), 'hard')


@dec.slow
def test_wifi80211_mimo_channel():
seed(17121996)
# Apply link_performance to MIMO 16QAM and 4x4 Rayleigh channel
RayleighChannel = MIMOFlatChannel(4, 4)
RayleighChannel.uncorr_rayleigh_fading(complex)

BERs = wifi80211.link_performance(RayleighChannel, arange(0, 21, 5) + 10 * log10(modem.num_bits_symbol), 10 ** 4,
600, receiver=receiver_mimo_wifi3)[0]
desired = (0.535, 0.508, 0.521, 0.554, 0.475) # From previous test
assert_allclose(BERs, desired, rtol=1.25,
err_msg='Wrong performance for MIMO 16QAM and 4x4 Rayleigh channel')


if __name__ == "__main__":
run_module_suite()