|
| 1 | +#ifndef INC_CPPANS_H_ |
| 2 | +#define INC_CPPANS_H_ |
| 3 | +/* |
| 4 | +USAGE: |
| 5 | +Put '#define CPPANS_IMPLEMENTATION' before including this file to create the implementation. |
| 6 | +*/ |
| 7 | +#define CPPANS_IMPLEMENTATION (1) |
| 8 | +#include <cstdint> |
| 9 | + |
namespace cppans
{
// Fixed-width integer aliases (shorthand for the <cstdint> types).
using s8 = int8_t;
using s16 = int16_t;
using s32 = int32_t;
using s64 = int64_t;

using u8 = uint8_t;
using u16 = uint16_t;
using u32 = uint32_t;
using u64 = uint64_t;

//! Byte-wise rANS (range Asymmetric Numeral System) entropy codec.
//! Static-only interface: encode() writes a frequency-table header followed by
//! the compressed stream; decode() reverses it.
class rANS
{
public:
    inline static constexpr u32 MaxSize = 0x7FFF'FFFFUL;        //!< largest supported buffer size, in bytes
    inline static constexpr u32 ProbBits = 14;                  //!< normalized frequencies sum to 2^ProbBits
    inline static constexpr u32 ProbScale = 1 << ProbBits;      //!< total of all normalized frequencies
    inline static constexpr u32 rANSByteLowBounds = 1UL << 23;  //!< lower bound of the state's normalization interval
    using State = u32;                                          //!< 32-bit coder state

    //! Precomputed per-symbol constants enabling division-free encoding.
    struct EncSymbol
    {
        u32 x_max_;     //!< upper bound of pre-normalization interval
        u32 rcp_freq_;  //!< fixed-point reciprocal of the symbol frequency
        u32 bias_;      //!< bias term added when encoding
        u16 cmpl_freq_; //!< complement of frequency: (1 << scale_bits) - freq
        u16 rcp_shift_; //!< shift applied after the reciprocal multiply
    };

    //! Upper bound (in bytes) on encode()'s output for 'size' input bytes.
    static u64 calc_encoded_size(u32 size);
    //! Compress src (src_size bytes) into dst; returns encoded size, 0 on failure.
    static u32 encode(u32 dst_size, u8* dst, u32 src_size, const u8* src);
    //! Decompress src into dst; returns the number of decoded bytes.
    static u32 decode(u32 dst_size, u8* dst, u32 src_size, const u8* src);

private:
    rANS(const rANS&) = delete;
    rANS& operator=(const rANS&) = delete;
};
} // namespace cppans
| 49 | + |
| 50 | +#ifdef CPPANS_IMPLEMENTATION |
| 51 | +# include <cassert> |
| 52 | +# include <cstring> |
| 53 | +# include <immintrin.h> |
| 54 | + |
| 55 | +# if defined(_MSC_VER) |
| 56 | +# define CPPANS_RESTRICT __restrict |
| 57 | +# elif defined(__gnuc__) |
| 58 | +# define CPPANS_RESTRICT __restrict |
| 59 | +# elif defined(__clang__) |
| 60 | +# define CPPANS_RESTRICT __restrict |
| 61 | +# else |
| 62 | +# endif |
| 63 | + |
| 64 | +namespace cppans |
| 65 | +{ |
| 66 | +namespace |
| 67 | +{ |
| 68 | + void count(u32* CPPANS_RESTRICT dst, u32 size, const u8* CPPANS_RESTRICT src) |
| 69 | + { |
| 70 | + ::memset(dst, 0, 256 * sizeof(u32)); |
| 71 | + u32 s = size & ~0x0FUL; |
| 72 | + for(u32 i = 0; i < s; i += 16) { |
| 73 | + ++dst[src[i + 0]]; |
| 74 | + ++dst[src[i + 1]]; |
| 75 | + ++dst[src[i + 2]]; |
| 76 | + ++dst[src[i + 3]]; |
| 77 | + ++dst[src[i + 4]]; |
| 78 | + ++dst[src[i + 5]]; |
| 79 | + ++dst[src[i + 6]]; |
| 80 | + ++dst[src[i + 7]]; |
| 81 | + |
| 82 | + ++dst[src[i + 8]]; |
| 83 | + ++dst[src[i + 9]]; |
| 84 | + ++dst[src[i + 10]]; |
| 85 | + ++dst[src[i + 11]]; |
| 86 | + ++dst[src[i + 12]]; |
| 87 | + ++dst[src[i + 13]]; |
| 88 | + ++dst[src[i + 14]]; |
| 89 | + ++dst[src[i + 15]]; |
| 90 | + } |
| 91 | + for(u32 i = s; i < size; ++i) { |
| 92 | + ++dst[src[i]]; |
| 93 | + } |
| 94 | + } |
| 95 | + |
| 96 | + void cumulative(u32* CPPANS_RESTRICT dst, const u32* CPPANS_RESTRICT src) |
| 97 | + { |
| 98 | + dst[0] = 0; |
| 99 | + for(u32 i = 0; i < 256; ++i) { |
| 100 | + dst[i + 1] = dst[i] + src[i]; |
| 101 | + } |
| 102 | + } |
| 103 | + |
    // Rescale the frequency model so all frequencies sum exactly to
    // target_total, while guaranteeing that every symbol that occurred in the
    // input keeps a frequency of at least 1. Both freqs[256] and
    // cum_freqs[257] are updated in place.
    void normalize(u32* CPPANS_RESTRICT freqs, u32* CPPANS_RESTRICT cum_freqs, u64 target_total)
    {
        u32 current_total = cum_freqs[256];
        // Proportionally rescale the cumulative table; the 64-bit product
        // avoids overflow before the division.
        for(u32 i = 1; i < 257; ++i) {
            cum_freqs[i] = (target_total * cum_freqs[i]) / current_total;
        }
        // Rounding can squeeze a used symbol down to frequency 0. For each
        // such symbol, steal one count from the smallest symbol that can
        // afford it (frequency > 1).
        for(u32 i = 0; i < 256; ++i) {
            if(freqs[i] && cum_freqs[i + 1] == cum_freqs[i]) {
                u32 best_freq = ~0UL;
                s32 best_steal = -1;
                for(s32 j = 0; j < 256; ++j) {
                    u32 freq = cum_freqs[j + 1] - cum_freqs[j];
                    if(1 < freq && freq < best_freq) {
                        best_freq = freq;
                        best_steal = j;
                    }
                }
                assert(-1 != best_steal);
                // Shift every cumulative boundary between the victim and i by
                // one: the victim loses a count, symbol i gains one, and all
                // other frequencies are unchanged.
                if(static_cast<u32>(best_steal) < i) {
                    for(s32 j = best_steal + 1; j <= i; ++j) {
                        --cum_freqs[j];
                    }
                } else {
                    assert(i < best_steal);
                    for(s32 j = i + 1; j <= best_steal; ++j) {
                        ++cum_freqs[j];
                    }
                }
            }
        }
        // Recompute per-symbol frequencies from the adjusted cumulative table.
        for(u32 i = 0; i < 256; ++i) {
#    if _DEBUG
            // Sanity: absent symbols stay at 0, present symbols stay >= 1.
            if(0 == freqs[i]) {
                assert(cum_freqs[i] == cum_freqs[i + 1]);
            } else {
                assert(cum_freqs[i] < cum_freqs[i + 1]);
            }
#    endif
            freqs[i] = cum_freqs[i + 1] - cum_freqs[i];
        }
    }
| 145 | + |
    // Precompute the encoder constants for one symbol.
    //   start:      the symbol's cumulative frequency (interval start)
    //   freq:       the symbol's frequency
    //   scale_bits: log2 of the total; all frequencies sum to 1 << scale_bits
    void init(rANS::EncSymbol& symbol, u32 start, u32 freq, u32 scale_bits)
    {
        assert(scale_bits <= 16);
        assert(start <= (1u << scale_bits));
        assert(freq <= (1u << scale_bits) - start);

        // Say M := 1 << scale_bits.
        //
        // The original encoder does:
        //   x_new = (x/freq)*M + start + (x%freq)
        //
        // The fast encoder does (schematically):
        //   q     = mul_hi(x, rcp_freq) >> rcp_shift   (division)
        //   r     = x - q*freq                         (remainder)
        //   x_new = q*M + bias + r                     (new x)
        // plugging in r into x_new yields:
        //   x_new = bias + x + q*(M - freq)
        //        =: bias + x + q*cmpl_freq             (*)
        //
        // and we can just precompute cmpl_freq. Now we just need to
        // set up our parameters such that the original encoder and
        // the fast encoder agree.

        symbol.x_max_ = ((rANS::rANSByteLowBounds >> scale_bits) << 8) * freq;
        symbol.cmpl_freq_ = static_cast<u16>((1 << scale_bits) - freq);
        if(freq < 2) {
            // freq=0 symbols are never valid to encode, so it doesn't matter what
            // we set our values to.
            //
            // freq=1 is tricky, since the reciprocal of 1 is 1; unfortunately,
            // our fixed-point reciprocal approximation can only multiply by values
            // smaller than 1.
            //
            // So we use the "next best thing": rcp_freq=0xffffffff, rcp_shift=0.
            // This gives:
            //   q = mul_hi(x, rcp_freq) >> rcp_shift
            //     = mul_hi(x, (1<<32) - 1)) >> 0
            //     = floor(x - x/(2^32))
            //     = x - 1 if 1 <= x < 2^32
            // and we know that x>0 (x=0 is never in a valid normalization interval).
            //
            // So we now need to choose the other parameters such that
            //   x_new = x*M + start
            // plug it in:
            //      x*M + start                   (desired result)
            //    = bias + x + q*cmpl_freq        (*)
            //    = bias + x + (x - 1)*(M - 1)    (plug in q=x-1, cmpl_freq)
            //    = bias + 1 + (x - 1)*M
            //    = x*M + (bias + 1 - M)
            //
            // so we have start = bias + 1 - M, or equivalently
            //   bias = start + M - 1.
            symbol.rcp_freq_ = ~0u;
            symbol.rcp_shift_ = 0;
            symbol.bias_ = start + (1 << scale_bits) - 1;
        } else {
            // Alverson, "Integer Division using reciprocals"
            // shift=ceil(log2(freq))
            u32 shift = 0;
            while(freq > (1UL << shift)) {
                shift++;
            }

            symbol.rcp_freq_ = static_cast<u32>(((1ULL << (shift + 31)) + freq - 1) / freq);
            symbol.rcp_shift_ = shift - 1;

            // With these values, 'q' is the correct quotient, so we
            // have bias=start.
            symbol.bias_ = start;
        }
    }
| 217 | + |
    // Start the encoder at the lower bound of the normalization interval;
    // any smaller state would be invalid for the coder.
    inline void init(rANS::State& state)
    {
        state = rANS::rANSByteLowBounds;
    }
| 222 | + |
| 223 | + void put(rANS::State& r, u8*& dst, const rANS::EncSymbol& symbol) |
| 224 | + { |
| 225 | + assert(0 < symbol.x_max_); |
| 226 | + |
| 227 | + // renormalize |
| 228 | + u32 x = r; |
| 229 | + u32 x_max = symbol.x_max_; |
| 230 | + if(x_max <= x) { |
| 231 | + u8* ptr = dst; |
| 232 | + do { |
| 233 | + *--ptr = static_cast<u8>(x & 0xFFUL); |
| 234 | + x >>= 8; |
| 235 | + } while(x_max <= x); |
| 236 | + dst = ptr; |
| 237 | + } |
| 238 | + |
| 239 | + // x = C(s,x) |
| 240 | + // NOTE: written this way so we get a 32-bit "multiply high" when |
| 241 | + // available. If you're on a 64-bit platform with cheap multiplies |
| 242 | + // (e.g. x64), just bake the +32 into rcp_shift. |
| 243 | + u32 q = static_cast<u32>((static_cast<uint64_t>(x) * symbol.rcp_freq_) >> 32) >> symbol.rcp_shift_; |
| 244 | + r = x + symbol.bias_ + q * symbol.cmpl_freq_; |
| 245 | + } |
| 246 | + |
| 247 | + void flush(u8*& dst, const rANS::State& r) |
| 248 | + { |
| 249 | + u32 x = r; |
| 250 | + u8* ptr = dst; |
| 251 | + ptr -= 4; |
| 252 | + ptr[0] = static_cast<u8>(x >> 0); |
| 253 | + ptr[1] = static_cast<u8>(x >> 8); |
| 254 | + ptr[2] = static_cast<u8>(x >> 16); |
| 255 | + ptr[3] = static_cast<u8>(x >> 24); |
| 256 | + dst = ptr; |
| 257 | + } |
| 258 | + |
| 259 | + // Initializes a rANS decoder. |
| 260 | + // Unlike the encoder, the decoder works forwards as you'd expect. |
| 261 | + void init_decode(rANS::State& r, const u8*& ptr) |
| 262 | + { |
| 263 | + r = ptr[0] << 0; |
| 264 | + r |= ptr[1] << 8; |
| 265 | + r |= ptr[2] << 16; |
| 266 | + r |= ptr[3] << 24; |
| 267 | + ptr += 4; |
| 268 | + } |
| 269 | + |
| 270 | + // Returns the current cumulative frequency (map it to a symbol yourself!) |
| 271 | + inline u32 get(rANS::State& r, u32 scale_bits) |
| 272 | + { |
| 273 | + return r & ((1UL << scale_bits) - 1); |
| 274 | + } |
| 275 | + |
| 276 | + // Advances in the bit stream by "popping" a single symbol with range start |
| 277 | + // "start" and frequency "freq". All frequencies are assumed to sum to "1 << scale_bits", |
| 278 | + // and the resulting bytes get written to ptr (which is updated). |
| 279 | + void advance(rANS::State& r, const u8*& ptr, u32 start, u32 freq, u32 scale_bits) |
| 280 | + { |
| 281 | + u32 mask = (1UL << scale_bits) - 1; |
| 282 | + // s, x = D(x) |
| 283 | + u32 x = r; |
| 284 | + x = freq * (x >> scale_bits) + (x & mask) - start; |
| 285 | + // renormalize |
| 286 | + if(x < rANS::rANSByteLowBounds) { |
| 287 | + do { |
| 288 | + x = (x << 8) | *ptr++; |
| 289 | + } while(x < rANS::rANSByteLowBounds); |
| 290 | + } |
| 291 | + r = x; |
| 292 | + } |
| 293 | + |
| 294 | +} // namespace |
| 295 | + |
| 296 | +u64 rANS::calc_encoded_size(u32 size) |
| 297 | +{ |
| 298 | + return static_cast<u64>(size) * 2 + sizeof(u32) * 258; |
| 299 | +} |
| 300 | + |
| 301 | +u32 rANS::encode(u32 dst_size, u8* dst, u32 src_size, const u8* src) |
| 302 | +{ |
| 303 | + assert(0 < dst_size); |
| 304 | + assert(nullptr != dst); |
| 305 | + assert(0 < src_size); |
| 306 | + assert(nullptr != src); |
| 307 | + |
| 308 | + u32 freqs[256]; |
| 309 | + count(freqs, src_size, src); |
| 310 | + u32 cum_freqs[257]; |
| 311 | + cumulative(cum_freqs, freqs); |
| 312 | + normalize(freqs, cum_freqs, ProbScale); |
| 313 | + EncSymbol symbols[256]; |
| 314 | + for(u32 i = 0; i < 256; ++i) { |
| 315 | + init(symbols[i], cum_freqs[i], freqs[i], ProbBits); |
| 316 | + } |
| 317 | + State rans; |
| 318 | + init(rans); |
| 319 | + u8* ptr = dst + dst_size; |
| 320 | + for(u32 i = src_size; 0 < i; --i) { |
| 321 | + u8 s = src[i - 1]; |
| 322 | + put(rans, ptr, symbols[s]); |
| 323 | + } |
| 324 | + flush(ptr, rans); |
| 325 | + ptr -= sizeof(u32) * 258; |
| 326 | + if(ptr < dst) { |
| 327 | + return 0; |
| 328 | + } |
| 329 | + u32* u32ptr = reinterpret_cast<u32*>(ptr) + 1; |
| 330 | + for(u32 i = 0; i < 257; ++i) { |
| 331 | + u32ptr[i] = cum_freqs[i]; |
| 332 | + } |
| 333 | + u32ptr[-1] = src_size; |
| 334 | + u32 encoded_size = static_cast<u32>(dst + dst_size - ptr); |
| 335 | + return encoded_size; |
| 336 | +} |
| 337 | + |
| 338 | +u32 rANS::decode(u32 dst_size, u8* dst, u32 src_size, const u8* src) |
| 339 | +{ |
| 340 | + assert(0 < dst_size); |
| 341 | + assert(nullptr != dst); |
| 342 | + assert(0 < src_size); |
| 343 | + assert(nullptr != src); |
| 344 | + assert(257 * sizeof(u32) <= src_size); |
| 345 | + u32 original_size; |
| 346 | + ::memcpy(&original_size, src, sizeof(u32)); |
| 347 | + const u32* cum_freqs = reinterpret_cast<const u32*>(src) + 1; |
| 348 | + u8 cum2sym[ProbScale] = {}; |
| 349 | + for(u32 s = 0; s < 256; ++s) { |
| 350 | + for(u32 i = cum_freqs[s]; i < cum_freqs[s + 1]; ++i) { |
| 351 | + cum2sym[i] = s; |
| 352 | + } |
| 353 | + } |
| 354 | + |
| 355 | + const u8* ptr = src + sizeof(u32) * 257; |
| 356 | + State rans; |
| 357 | + init_decode(rans, ptr); |
| 358 | + |
| 359 | + for(u32 i = 0; i < original_size; ++i) { |
| 360 | + u8 s = cum2sym[get(rans, ProbBits)]; |
| 361 | + dst[i] = s; |
| 362 | + u32 freq = cum_freqs[s+1] - cum_freqs[s]; |
| 363 | + advance(rans, ptr, cum_freqs[s], freq, ProbBits); |
| 364 | + } |
| 365 | + u32 decoded_size = static_cast<u32>(ptr - (src+sizeof(u32) * 257)); |
| 366 | + return original_size; |
| 367 | +} |
| 368 | +} // namespace cppans |
| 369 | +#endif |
| 370 | +#endif INC_CPPANS_H_ |
0 commit comments