{-|

Module      :  Data.Array.BitArray.ST
Copyright   :  (c) Claude Heiland-Allen 2012,2018
License     :  BSD3

Maintainer  :  claude@mathr.co.uk
Stability   :  unstable
Portability :  uses ST

Unboxed mutable bit arrays in the 'ST' monad.

-}
-- Almost everything here is implemented by lifting the IO-based implementation
-- (Data.Array.BitArray.IO) into ST with unsafeIOToST.
module Data.Array.BitArray.ST
  ( STBitArray()
  -- * MArray-like interface.
  , getBounds
  , newArray
  , newArray_
  , newListArray
  , readArray
  , writeArray
  , mapArray
  , mapIndices
  , getElems
  , getAssocs
  -- * Conversion to/from immutable bit arrays.
  , freeze
  , thaw
  -- * Construction.
  , copy
  , fill
  -- * Short-circuiting reductions.
  , or
  , and
  , isUniform
  , elemIndex
  -- * Aggregate operations.
  , fold
  , map
  , zipWith
  , popCount
  -- * Unsafe.
  , unsafeReadArray
  , unsafeGetElems
  , unsafeFreeze
  , unsafeThaw
  ) where

import Prelude hiding (and, or, map, zipWith)
import Control.Monad.ST (ST)
import Data.Ix (Ix)

import Control.Monad.ST.Unsafe (unsafeIOToST)
import Data.Array.BitArray.Internal (BitArray)
import Data.Array.BitArray.IO (IOBitArray)
import qualified Data.Array.BitArray.IO as IO

-- | The type of mutable bit arrays.
--
--   A thin wrapper around an 'IOBitArray'.  The @s@ parameter does not
--   appear on the right-hand side; it only tags the array with the 'ST'
--   state thread it belongs to.  The export list (@STBitArray()@) hides the
--   'STB' constructor, so callers can reach the underlying 'IOBitArray' only
--   through the functions in this module.
--
--   NOTE(review): because @s@ is phantom, GHC infers the phantom role for it,
--   so 'Data.Coerce.coerce' can change @s@ and let an array escape 'runST'.
--   Consider adding @type role STBitArray nominal nominal@ (requires the
--   RoleAnnotations extension) — confirm no callers rely on coercing it.
newtype STBitArray s i = STB (IOBitArray i)

-- | Get the bounds of a bit array.
--
--   Delegates to 'IO.getBounds' on the wrapped 'IOBitArray'.
{-# INLINE getBounds #-}
getBounds :: Ix i => STBitArray s i -> ST s (i, i)
getBounds (STB a) = unsafeIOToST (IO.getBounds a)

-- | Create a new array filled with an initial value.
--
--   Allocates via 'IO.newArray' and wraps the result in 'STB'.
{-# INLINE newArray #-}
newArray :: Ix i => (i, i) {- ^ bounds -} -> Bool {- ^ initial value -} -> ST s (STBitArray s i)
newArray bs b = STB `fmap` unsafeIOToST (IO.newArray bs b)

-- | Create a new array filled with a default initial value ('False').
--
--   Same as 'newArray' with an initial value of 'False'.
{-# INLINE newArray_ #-}
newArray_ :: Ix i => (i, i) {- ^ bounds -} -> ST s (STBitArray s i)
newArray_ bs = STB `fmap` unsafeIOToST (IO.newArray bs False)

-- | Create a new array filled with values from a list.
--
--   Delegates to 'IO.newListArray'.  How a list shorter or longer than the
--   bounds is handled is decided there (not visible from this module).
{-# INLINE newListArray #-}
newListArray :: Ix i => (i, i) {- ^ bounds -} -> [Bool] {- ^ elems -} -> ST s (STBitArray s i)
newListArray bs es = STB `fmap` unsafeIOToST (IO.newListArray bs es)

-- | Read from an array at an index.
--
--   Delegates to 'IO.readArray'.  This is the checked counterpart of
--   'unsafeReadArray'.
{-# INLINE readArray #-}
readArray :: Ix i => STBitArray s i -> i -> ST s Bool
readArray (STB a) i = unsafeIOToST (IO.readArray a i)

-- | Read from an array at an index without bounds checking.  Unsafe.
--
--   The caller must guarantee that the index lies within the array's
--   bounds; delegates to 'IO.unsafeReadArray'.
{-# INLINE unsafeReadArray #-}
unsafeReadArray :: Ix i => STBitArray s i -> i -> ST s Bool
unsafeReadArray (STB a) i = unsafeIOToST (IO.unsafeReadArray a i)

-- | Write to an array at an index.
--
--   Mutates the array in place via 'IO.writeArray'.
{-# INLINE writeArray #-}
writeArray :: Ix i => STBitArray s i -> i -> Bool -> ST s ()
writeArray (STB a) i b = unsafeIOToST (IO.writeArray a i b)

-- | Alias for 'map'.
--
--   Provided for parity with the @MArray@-style naming used elsewhere in
--   this module's export list.
{-# INLINE mapArray #-}
mapArray :: Ix i => (Bool -> Bool) -> STBitArray s i -> ST s (STBitArray s i)
mapArray = map

-- | Create a new array by reading from another.
--
--   Each index of the new array (within the new bounds) is translated by the
--   index transformation into an index of the source array; delegates to
--   'IO.mapIndices'.
{-# INLINE mapIndices #-}
mapIndices :: (Ix i, Ix j) => (i, i) {- ^ new bounds -} -> (i -> j) {- ^ index transformation -} -> STBitArray s j {- ^ source array -} -> ST s (STBitArray s i)
mapIndices bs h (STB a) = STB `fmap` unsafeIOToST (IO.mapIndices bs h a)

-- | Get a list of all elements of an array.
--
--   The safe counterpart of 'unsafeGetElems'; delegates to 'IO.getElems'.
{-# INLINE getElems #-}
getElems :: Ix i => STBitArray s i -> ST s [Bool]
getElems (STB a) = unsafeIOToST (IO.getElems a)

-- | Get a list of all elements of an array without copying.  Unsafe when
--   the source array can be modified later.
--
--   Delegates to 'IO.unsafeGetElems'; prefer 'getElems' unless the array is
--   known not to be written after this call.
{-# INLINE unsafeGetElems #-}
unsafeGetElems :: Ix i => STBitArray s i -> ST s [Bool]
unsafeGetElems (STB a) = unsafeIOToST (IO.unsafeGetElems a)

-- | Get a list of all (index, element) pairs.
--
--   Delegates to 'IO.getAssocs'.
{-# INLINE getAssocs #-}
getAssocs :: Ix i => STBitArray s i -> ST s [(i, Bool)]
getAssocs (STB a) = unsafeIOToST (IO.getAssocs a)

-- | Snapshot the array into an immutable form.
--
--   Unlike 'unsafeFreeze', this remains safe if the source array is
--   modified afterwards; delegates to 'IO.freeze'.
{-# INLINE freeze #-}
freeze :: Ix i => STBitArray s i -> ST s (BitArray i)
freeze (STB a) = unsafeIOToST (IO.freeze a)

-- | Snapshot the array into an immutable form.  Unsafe when the source
--   array can be modified later.
--
--   Delegates to 'IO.unsafeFreeze'; use 'freeze' unless the mutable array
--   is never written again.
{-# INLINE unsafeFreeze #-}
unsafeFreeze :: Ix i => STBitArray s i -> ST s (BitArray i)
unsafeFreeze (STB a) = unsafeIOToST (IO.unsafeFreeze a)

-- | Convert an array from immutable form.
--
--   The safe counterpart of 'unsafeThaw'; delegates to 'IO.thaw' and wraps
--   the resulting 'IOBitArray' in 'STB'.
{-# INLINE thaw #-}
thaw :: Ix i => BitArray i -> ST s (STBitArray s i)
thaw a = STB `fmap` unsafeIOToST (IO.thaw a)

-- | Convert an array from immutable form.  Unsafe to modify the result
--   unless the source array is never used later.
--
--   Delegates to 'IO.unsafeThaw' and wraps the result in 'STB'.
{-# INLINE unsafeThaw #-}
unsafeThaw :: Ix i => BitArray i -> ST s (STBitArray s i)
unsafeThaw a = STB `fmap` unsafeIOToST (IO.unsafeThaw a)

-- | Copy an array.
--
--   Returns a new mutable array produced by 'IO.copy'; the argument is
--   not modified.
{-# INLINE copy #-}
copy :: Ix i => STBitArray s i -> ST s (STBitArray s i)
copy (STB a) = STB `fmap` unsafeIOToST (IO.copy a)

-- | Fill an array with a uniform value.
--
--   Mutates the array in place via 'IO.fill'.
{-# INLINE fill #-}
fill :: Ix i => STBitArray s i -> Bool -> ST s ()
fill (STB a) b = unsafeIOToST (IO.fill a b)

-- | Short-circuit bitwise reduction: True when any bit is True.
--
--   Delegates to 'IO.or'.
{-# INLINE or #-}
or :: Ix i => STBitArray s i -> ST s Bool
or (STB a) = unsafeIOToST (IO.or a)

-- | Short-circuit bitwise reduction: False when any bit is False.
--
--   Delegates to 'IO.and'.
{-# INLINE and #-}
and :: Ix i => STBitArray s i -> ST s Bool
and (STB a) = unsafeIOToST (IO.and a)

-- | Short-circuit bitwise reduction: 'Nothing' when any bits differ,
--   'Just' when all bits are the same.
--
--   Delegates to 'IO.isUniform'.
{-# INLINE isUniform #-}
isUniform :: Ix i => STBitArray s i -> ST s (Maybe Bool)
isUniform (STB a) = unsafeIOToST (IO.isUniform a)

-- | Look up index of first matching bit.
--
--   Note that the index type is limited to Int because there
--   is no 'unindex' method in the 'Ix' class.
--
--   Delegates to 'IO.elemIndex'.
{-# INLINE elemIndex #-}
elemIndex :: Bool -> STBitArray s Int -> ST s (Maybe Int)
elemIndex b (STB a) = unsafeIOToST (IO.elemIndex b a)

-- | Bitwise reduction with an associative commutative boolean operator.
--   Implementation lifts from 'Bool' to 'Bits' and folds large chunks
--   at a time.  Each bit is used as a source exactly once.
--
--   Delegates to 'IO.fold'.  The result is a 'Maybe'; presumably 'Nothing'
--   signals an empty array — confirm against 'IO.fold'.
{-# INLINE fold #-}
fold :: Ix i => (Bool -> Bool -> Bool) {- ^ operator -} -> STBitArray s i -> ST s (Maybe Bool)
fold f (STB a) = unsafeIOToST (IO.fold f a)

-- | Bitwise map.  Implementation lifts from 'Bool' to 'Bits' and maps
--   large chunks at a time.
--
--   Produces a new array via 'IO.map'; the argument array is passed through
--   unchanged to the IO implementation.
{-# INLINE map #-}
map :: Ix i => (Bool -> Bool) -> STBitArray s i -> ST s (STBitArray s i)
map f (STB a) = STB `fmap` unsafeIOToST (IO.map f a)

-- | Bitwise zipWith.  Implementation lifts from 'Bool' to 'Bits' and
--   combines large chunks at a time.
--
--   The bounds of the source arrays must be identical.  What happens when
--   they differ is decided by 'IO.zipWith' (not visible from this module).
{-# INLINE zipWith #-}
zipWith :: Ix i => (Bool -> Bool -> Bool) -> STBitArray s i -> STBitArray s i -> ST s (STBitArray s i)
zipWith f (STB a) (STB b) = STB `fmap` unsafeIOToST (IO.zipWith f a b)

-- | Count set bits.
--
--   Returns the number of 'True' elements, as computed by 'IO.popCount'.
{-# INLINE popCount #-}
popCount :: Ix i => STBitArray s i -> ST s Int
popCount (STB a) = unsafeIOToST (IO.popCount a)