{-|
  Maintainer: Thomas.DuBuisson@gmail.com
  Stability: beta
  Portability: portable

PKCS5 (RFC 1423) and IPSec ESP (RFC 4303)
padding methods are implemented both as trivial functions operating on
bytestrings and as 'Put' routines usable from the "Data.Serialize"
module.  These methods do not work for algorithms or pad sizes in
excess of 255 bytes (2040 bits, so extremely large as far as cipher
needs are concerned).

-}

module Crypto.Padding
        (
        -- * PKCS5 (RFC 1423) based [un]padding routines
          padPKCS5
        , padBlockSize
        , putPaddedPKCS5
        , unpadPKCS5safe
        , unpadPKCS5
        -- * ESP (RFC 4303) [un]padding routines
        , padESP, unpadESP
        , padESPBlockSize
        , putPadESPBlockSize, putPadESP
        ) where

import Data.Serialize.Put
import Crypto.Classes
import Crypto.Types
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L

-- | PKCS5 (aka RFC 1423) padding of a strict bytestring to a multiple of
-- @len@ bytes.
--
-- This simply runs 'putPaddedPKCS5' and collects its output; see that
-- function for the exact pad layout.  Pad moduli above 255 are not
-- supported, because the pad count must fit in a single byte.
padPKCS5 :: ByteLength -> B.ByteString -> B.ByteString
padPKCS5 len = runPut . putPaddedPKCS5 len

-- | Put a bytestring followed by PKCS5 padding via 'Put'.
--
-- Ex:
--
-- @
--     putPaddedPKCS5 m bs
-- @
--
-- writes @bs@ followed by @p@ copies of the byte @p@, where
-- @p = m - rem (B.length bs) m@.  For a positive @m@ this gives
-- @1 <= p <= m@, so the total length is a multiple of @m@.  An input that
-- is already aligned still receives one full block of padding.
--
-- Emitting through 'Put' saves a copy if you are already serializing with
-- Cereal.
--
-- A modulus of @0@ is special-cased: @bs@ is followed by a single pad byte
-- of value @1@.  Moduli above 255 are not supported, because the pad count
-- is truncated to one byte.
putPaddedPKCS5 :: ByteLength -> B.ByteString -> Put
putPaddedPKCS5 0 bs = putByteString bs >> putWord8 1
putPaddedPKCS5 len bs = putByteString bs >> putByteString pad
  where
    -- |rem| < |len|, so padLen is never 0 here.  The former
    -- "if r == 0 then len" branch was unreachable and has been dropped.
    padLen = len - (B.length bs `rem` len)
    -- Every pad byte holds the pad count (truncated to a byte).
    pad    = B.replicate padLen (fromIntegral padLen)

-- | PKCS5 (aka RFC 1423) padding whose modulus is the block size, in bytes,
-- of the given 'BlockCipher' instance.
--
-- Equivalent to running 'putPaddedBlockSize' and collecting its output.
padBlockSize :: BlockCipher k => k -> B.ByteString -> B.ByteString
padBlockSize k bs = runPut (putPaddedBlockSize k bs)

-- | Put the bytestring plus PKCS5 padding sized for the given block cipher.
--
-- The pad modulus is the cipher's 'blockSizeBytes'; the work itself is
-- delegated to 'putPaddedPKCS5'.
putPaddedBlockSize :: BlockCipher k => k -> B.ByteString -> Put
putPaddedBlockSize k = putPaddedPKCS5 blockLen
  where
    blockLen = blockSizeBytes `for` k

-- | Unpad a strict bytestring padded in the typical PKCS5 manner.
--
-- The last byte is read as the pad count @p@.  The input is accepted only if:
--
-- * it is non-empty,
--
-- * @p@ is at least 1 (a pad byte of 0 is never produced by PKCS5 padding),
--
-- * the input holds at least @p@ bytes, and
--
-- * all of the final @p@ bytes equal @p@.
--
-- On success the message without its padding is returned; otherwise
-- 'Nothing'.
unpadPKCS5safe :: B.ByteString -> Maybe B.ByteString
unpadPKCS5safe bs
        | validPad  = Just msg
        | otherwise = Nothing
  where
    bsLen   = B.length bs
    -- Only forced after the `bsLen > 0` check below, so B.last is safe.
    padByte = B.last bs
    pLen    = fromIntegral padByte :: Int
    -- If pLen > bsLen, splitAt gets a negative index and pad is the whole
    -- input; the length check below then rejects it.
    (msg, pad) = B.splitAt (bsLen - pLen) bs
    validPad = bsLen > 0
            && pLen > 0
            && B.length pad == pLen
            && B.all (== padByte) pad

-- | Unpad a strict bytestring without validating the padding.
--
-- The last byte is taken as the pad count and that many bytes are dropped
-- from the end.  An empty input is returned unchanged.  If the pad count
-- exceeds the input length, the result is empty.  A pad count of 0 leaves
-- the input unchanged.  Use 'unpadPKCS5safe' when the padding must be
-- checked.
unpadPKCS5 :: B.ByteString -> B.ByteString
unpadPKCS5 bs
  | B.null bs = bs
  | otherwise = B.take keep bs
  where
    keep = B.length bs - fromIntegral (B.last bs)

-- | Pad a bytestring following the IPSec ESP (RFC 4303) padding scheme.
--
-- > padESP m payload
--
-- is equivalent to:
--
-- @
--               (msg)          (padding)            (length field)
--     B.concat [payload, B.pack [1, 2 .. padLen], B.pack [padLen]]
-- @
--
-- Where:
--
-- * the msg is any payload, including TFC.
--
-- * the padding is the increasing byte sequence @1, 2, .., padLen@, with
--   @padLen <= 255@.
--
-- * the length field is a single byte holding @padLen@.
--
-- For @0 < m <= 255@ the result length is a multiple of @m@.  Because there
-- is no \"next header\" field, this function is not directly usable for an
-- IPSec implementation.  If you are building IPSec ESP, copy 'putPadESP'
-- and add a \"next header\" field.
padESP :: Int -> B.ByteString -> B.ByteString
padESP i = runPut . putPadESP i

-- | Like 'padESP', but the padding modulus is taken from the 'BlockCipher'
-- instance (via 'putPadESPBlockSize').
padESPBlockSize :: BlockCipher k => k -> B.ByteString -> B.ByteString
padESPBlockSize k = runPut . putPadESPBlockSize k

-- | Like 'putPadESP', but the padding modulus is the cipher's
-- 'blockSizeBytes'.
putPadESPBlockSize :: BlockCipher k => k -> B.ByteString -> Put
putPadESPBlockSize k = putPadESP blockLen
  where
    blockLen = blockSizeBytes `for` k

-- | Put a bytestring padded in the IPSec ESP (RFC 4303) style via 'Put'.
-- This avoids a copy if you are already serializing with 'Put'.
--
-- Output layout: the payload, then the pad bytes @1, 2, .., padLen@ (a
-- prefix of 'espPad'), then one byte holding @padLen@.  Here
-- @padLen = l - rem (B.length bs + 1) l@, which lies in @[1, l]@ for a
-- positive @l@.  An already aligned @payload ++ lengthByte@ therefore gets
-- a full block of padding.  A modulus of @0@ writes the payload followed by
-- a single @0@ length byte.
putPadESP :: Int -> B.ByteString -> Put
putPadESP 0 bs = putByteString bs >> putWord8 0
putPadESP l bs =
  sequence_
    [ putByteString bs
    , putByteString (B.take padLen espPad)
      -- Truncated to one byte, hence the 255 limit on moduli.
    , putWord8 (fromIntegral padLen)
    ]
  where
    -- The "+ 1" accounts for the trailing length byte.
    padLen = l - ((B.length bs + 1) `rem` l)

-- | The fixed ESP padding sequence @1, 2, .., 255@.
--
-- Keeping it as a top-level constant means it is packed once and shared.
-- 'putPadESP' and 'unpadESP' both work on prefixes of it.
espPad :: B.ByteString
espPad = B.pack (enumFromTo 1 255)

-- | Unpad an ESP-padded bytestring and return the message, or 'Nothing' if
-- the padding is invalid.
--
-- The last byte is read as the pad length @p@.  The input is accepted only
-- if it is non-empty and holds at least @p + 1@ bytes (padding plus length
-- byte).  In addition, the @p@ bytes before the length byte must equal
-- @1, 2, .., p@.  That comparison uses 'constTimeEq'.
unpadESP :: B.ByteString -> Maybe B.ByteString
unpadESP bs
        | bsLen == 0       = Nothing
        -- Without this check, splitAt gets a negative index and the length
        -- byte itself is compared as padding.  Malformed inputs such as
        -- [1] or [1,2] were then accepted as Just "".
        | pLen + 1 > bsLen = Nothing
        | not (constTimeEq (B.take pLen pad) (B.take pLen espPad)) = Nothing
        | otherwise        = Just msg
  where
    bsLen = B.length bs
    -- Only forced after the `bsLen == 0` guard, so B.last is safe.
    pLen  = fromIntegral (B.last bs)
    -- pad holds the padding bytes followed by the length byte.
    (msg, pad) = B.splitAt (bsLen - (pLen + 1)) bs