Skip to content

Commit 527b030

Browse files
committed
add ans coder
1 parent 86cffd9 commit 527b030

4 files changed

Lines changed: 530 additions & 25 deletions

File tree

cppans.h

Lines changed: 370 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,370 @@
1+
#ifndef INC_CPPANS_H_
2+
#define INC_CPPANS_H_
3+
/*
4+
USAGE:
5+
Put '#define CPPANS_IMPLEMENTATION' before including this file to create the implementation.
6+
*/
7+
#define CPPANS_IMPLEMENTATION (1)
8+
#include <cstdint>
9+
10+
namespace cppans
{
// Fixed-width integer shorthands used throughout the coder.
using u8 = std::uint8_t;
using u16 = std::uint16_t;
using u32 = std::uint32_t;
using u64 = std::uint64_t;

using s8 = std::int8_t;
using s16 = std::int16_t;
using s32 = std::int32_t;
using s64 = std::int64_t;

//! Byte-oriented rANS coder. Only static members; the type is never instantiated.
class rANS
{
public:
    //! Coder state word.
    using State = u32;

    static constexpr u32 MaxSize = 0x7FFFFFFFu;
    //! Number of bits of precision of the normalized frequency table.
    static constexpr u32 ProbBits = 14;
    //! Total of the normalized frequency table (1 << ProbBits).
    static constexpr u32 ProbScale = 1u << ProbBits;
    //! Lower bound of the normalization interval of the state.
    static constexpr u32 rANSByteLowBounds = 1u << 23;

    //! Precomputed per-symbol constants used by the encoder.
    struct EncSymbol
    {
        u32 x_max_;     //!< (Exclusive) upper bound of pre-normalization interval
        u32 rcp_freq_;  //!< Fixed-point reciprocal of the frequency
        u32 bias_;      //!< Bias added to the state when encoding
        u16 cmpl_freq_; //!< Complement of frequency: (1 << scale_bits) - freq
        u16 rcp_shift_; //!< Shift applied after the reciprocal multiply
    };

    //! Suggested destination buffer size for encoding `size` input bytes.
    static u64 calc_encoded_size(u32 size);

    //! Encodes src[0 .. src_size) into the tail of dst.
    //! Returns the number of bytes produced, or 0 if they did not fit; the
    //! produced bytes occupy the last `return value` bytes of dst.
    static u32 encode(u32 dst_size, u8* dst, u32 src_size, const u8* src);

    //! Decodes a buffer produced by encode() into dst.
    //! Returns the decoded size recorded in the stream header.
    static u32 decode(u32 dst_size, u8* dst, u32 src_size, const u8* src);

private:
    rANS(const rANS&) = delete;
    rANS& operator=(const rANS&) = delete;
};
} // namespace cppans
49+
50+
#ifdef CPPANS_IMPLEMENTATION
51+
# include <cassert>
52+
# include <cstring>
53+
# include <immintrin.h>
54+
55+
// CPPANS_RESTRICT: no-alias qualifier for pointer parameters.
// MSVC, GCC and Clang all accept the `__restrict` spelling. GCC is detected
// through `__GNUC__` (the macro is upper-case). Any other compiler gets an
// empty definition so that declarations using the macro still compile.
# if defined(_MSC_VER) || defined(__GNUC__) || defined(__clang__)
#  define CPPANS_RESTRICT __restrict
# else
#  define CPPANS_RESTRICT
# endif
63+
64+
namespace cppans
65+
{
66+
namespace
67+
{
68+
// Builds a byte histogram of src[0 .. size).
// dst must provide 256 counters; they are cleared first and dst[v] ends up
// holding the number of occurrences of byte value v.
void count(u32* CPPANS_RESTRICT dst, u32 size, const u8* CPPANS_RESTRICT src)
{
    ::memset(dst, 0, 256 * sizeof(u32));

    // Bulk part: whole groups of 16 bytes.
    const u32 bulk = size & ~0x0FUL;
    const u8* cursor = src;
    const u8* const bulk_end = src + bulk;
    while(cursor != bulk_end) {
        for(u32 k = 0; k < 16; ++k) {
            ++dst[cursor[k]];
        }
        cursor += 16;
    }

    // Remainder: the last (size % 16) bytes.
    const u8* const end = src + size;
    for(; cursor != end; ++cursor) {
        ++dst[*cursor];
    }
}
95+
96+
// Exclusive prefix sum over 256 frequencies.
// dst receives 257 entries: dst[0] = 0 and dst[k + 1] = src[0] + ... + src[k].
void cumulative(u32* CPPANS_RESTRICT dst, const u32* CPPANS_RESTRICT src)
{
    u32 running = 0;
    dst[0] = running;
    for(u32 symbol = 0; symbol < 256; ++symbol) {
        running += src[symbol];
        dst[symbol + 1] = running;
    }
}
103+
104+
// Rescales a cumulative frequency table so that its total becomes target_total,
// then repairs it so every symbol that occurs keeps a non-empty range.
//
// freqs     in: raw counts of the 256 symbols; out: the rescaled frequencies.
// cum_freqs in: 257 running sums (cum_freqs[256] is the current total);
//           out: the rescaled running sums.
void normalize(u32* CPPANS_RESTRICT freqs, u32* CPPANS_RESTRICT cum_freqs, u64 target_total)
{
    // Scale every boundary; cum_freqs[0] is left as it is.
    const u32 current_total = cum_freqs[256];
    for(u32 k = 1; k <= 256; ++k) {
        cum_freqs[k] = static_cast<u32>((target_total * cum_freqs[k]) / current_total);
    }

    for(u32 i = 0; i < 256; ++i) {
        // Only symbols that occur but were rounded down to an empty range need fixing.
        const bool squeezed_out = (0 != freqs[i]) && (cum_freqs[i] == cum_freqs[i + 1]);
        if(!squeezed_out) {
            continue;
        }

        // Pick the donor: the first symbol with the smallest range wider than one slot.
        s32 best_steal = -1;
        u32 best_freq = ~0UL;
        for(s32 j = 0; j < 256; ++j) {
            const u32 width = cum_freqs[j + 1] - cum_freqs[j];
            if(width <= 1 || best_freq <= width) {
                continue;
            }
            best_freq = width;
            best_steal = j;
        }
        assert(-1 != best_steal);

        // Move one slot from the donor to symbol i by shifting every boundary
        // that lies between the two.
        const u32 donor = static_cast<u32>(best_steal);
        if(donor < i) {
            for(u32 j = donor + 1; j <= i; ++j) {
                --cum_freqs[j];
            }
        } else {
            assert(i < donor);
            for(s32 j = static_cast<s32>(i) + 1; j <= best_steal; ++j) {
                ++cum_freqs[j];
            }
        }
    }

    // Rebuild the per-symbol frequencies from the adjusted boundaries.
    for(u32 i = 0; i < 256; ++i) {
#if _DEBUG
        if(0 == freqs[i]) {
            assert(cum_freqs[i] == cum_freqs[i + 1]);
        } else {
            assert(cum_freqs[i] < cum_freqs[i + 1]);
        }
#endif
        freqs[i] = cum_freqs[i + 1] - cum_freqs[i];
    }
}
145+
146+
// Precomputes the encoder constants for one symbol.
//
// symbol     receives the constants
// start      cumulative frequency at which the symbol's range begins
// freq       width of the symbol's range
// scale_bits log2 of the total M of all frequencies (M := 1 << scale_bits)
//
// The plain encoder step is
//     x_new = (x / freq) * M + start + (x % freq)
// The fast step replaces the division by a reciprocal multiply:
//     q     = mul_hi(x, rcp_freq) >> rcp_shift
//     x_new = bias + x + q * (M - freq)  =  bias + x + q * cmpl_freq      (*)
// The values stored here are chosen so that both forms agree.
void init(rANS::EncSymbol& symbol, u32 start, u32 freq, u32 scale_bits)
{
    assert(scale_bits <= 16);
    assert(start <= (1u << scale_bits));
    assert(freq <= (1u << scale_bits) - start);

    const u32 total = 1u << scale_bits;
    symbol.x_max_ = ((rANS::rANSByteLowBounds >> scale_bits) << 8) * freq;
    symbol.cmpl_freq_ = static_cast<u16>(total - freq);

    if(freq < 2) {
        // freq == 0 is never encoded, so its values are irrelevant.
        //
        // freq == 1 needs a reciprocal of exactly 1, which the fixed-point
        // form cannot express. With rcp_freq = 0xffffffff and rcp_shift = 0,
        //     q = mul_hi(x, 2^32 - 1) = x - 1      for 1 <= x < 2^32
        // (x is never 0 inside a valid normalization interval). Substituting
        // q = x - 1 and cmpl_freq = M - 1 into (*):
        //     bias + x + (x - 1) * (M - 1) = x * M + (bias + 1 - M)
        // The wanted result is x * M + start, hence bias = start + M - 1.
        symbol.rcp_freq_ = ~0u;
        symbol.rcp_shift_ = 0;
        symbol.bias_ = start + total - 1;
        return;
    }

    // Alverson, "Integer Division using reciprocals":
    // shift = ceil(log2(freq))
    u32 shift = 0;
    for(; freq > (1UL << shift); ++shift) {
    }

    const unsigned long long numerator = (1ULL << (shift + 31)) + freq - 1;
    symbol.rcp_freq_ = static_cast<u32>(numerator / freq);
    symbol.rcp_shift_ = shift - 1;

    // Here q is the exact quotient, so no correction is needed: bias = start.
    symbol.bias_ = start;
}
217+
218+
// Resets a coder state to the lower bound of the normalization interval.
inline void init(rANS::State& state)
{
    const rANS::State initial = rANS::rANSByteLowBounds;
    state = initial;
}
222+
223+
// Encodes one symbol into the state r.
// Renormalization bytes are stored at decreasing addresses: dst is moved down
// by one for every byte emitted and finally points at the last byte written.
void put(rANS::State& r, u8*& dst, const rANS::EncSymbol& symbol)
{
    assert(0 < symbol.x_max_);

    // Renormalize: shed low bytes until the state is below the symbol's bound.
    u32 state = r;
    u8* out = dst;
    while(symbol.x_max_ <= state) {
        *--out = static_cast<u8>(state & 0xFFUL);
        state >>= 8;
    }
    dst = out;

    // x = C(s,x)
    // The product is taken in 64 bits and its upper half used, i.e. a 32-bit
    // "multiply high"; rcp_shift_ is applied to that upper half.
    const uint64_t product = static_cast<uint64_t>(state) * symbol.rcp_freq_;
    const u32 quotient = static_cast<u32>(product >> 32) >> symbol.rcp_shift_;
    r = state + symbol.bias_ + quotient * symbol.cmpl_freq_;
}
246+
247+
// Stores the final state as four bytes, least significant first, directly
// below dst, and moves dst down to the first of those bytes.
void flush(u8*& dst, const rANS::State& r)
{
    const u32 x = r;
    u8* const out = dst - 4;
    for(u32 k = 0; k < 4; ++k) {
        out[k] = static_cast<u8>(x >> (8 * k));
    }
    dst = out;
}
258+
259+
// Initializes a rANS decoder.
260+
// Unlike the encoder, the decoder works forwards as you'd expect.
261+
void init_decode(rANS::State& r, const u8*& ptr)
262+
{
263+
r = ptr[0] << 0;
264+
r |= ptr[1] << 8;
265+
r |= ptr[2] << 16;
266+
r |= ptr[3] << 24;
267+
ptr += 4;
268+
}
269+
270+
// Returns the current cumulative frequency (map it to a symbol yourself!)
271+
inline u32 get(rANS::State& r, u32 scale_bits)
272+
{
273+
return r & ((1UL << scale_bits) - 1);
274+
}
275+
276+
// Advances in the bit stream by "popping" a single symbol with range start
277+
// "start" and frequency "freq". All frequencies are assumed to sum to "1 << scale_bits",
278+
// and the resulting bytes get written to ptr (which is updated).
279+
void advance(rANS::State& r, const u8*& ptr, u32 start, u32 freq, u32 scale_bits)
280+
{
281+
u32 mask = (1UL << scale_bits) - 1;
282+
// s, x = D(x)
283+
u32 x = r;
284+
x = freq * (x >> scale_bits) + (x & mask) - start;
285+
// renormalize
286+
if(x < rANS::rANSByteLowBounds) {
287+
do {
288+
x = (x << 8) | *ptr++;
289+
} while(x < rANS::rANSByteLowBounds);
290+
}
291+
r = x;
292+
}
293+
294+
} // namespace
295+
296+
// Returns a destination size that is always large enough for encode().
//
// The output of encode() consists of
//   - a header of 258 u32 values (original size + 257 cumulative frequencies),
//   - the 4-byte final coder state written by flush(),
//   - at most 2 renormalization bytes per input byte.
// The 4 state bytes used to be missing from the bound, which made it too
// small for very short inputs (e.g. size == 1 needs 1036 bytes, not 1034).
u64 rANS::calc_encoded_size(u32 size)
{
    const u64 header_bytes = sizeof(u32) * 258;
    const u64 state_bytes = sizeof(State);
    const u64 payload_bytes = static_cast<u64>(size) * 2;
    return payload_bytes + state_bytes + header_bytes;
}
300+
301+
// Compresses src[0 .. src_size) into dst.
//
// The stream is built backwards from the end of dst, so the encoded bytes
// occupy the LAST `return value` bytes of dst, i.e. they start at
// dst + dst_size - return value. Layout of the encoded bytes:
//   u32       original size (native byte order)
//   u32[257]  normalized cumulative frequencies (native byte order)
//   u8[4]     final coder state, least significant byte first
//   u8[...]   renormalization bytes
//
// Returns the number of encoded bytes, or 0 when the arguments are invalid or
// dst is too small. Nothing is ever written outside dst[0 .. dst_size).
u32 rANS::encode(u32 dst_size, u8* dst, u32 src_size, const u8* src)
{
    assert(0 < dst_size);
    assert(nullptr != dst);
    assert(0 < src_size);
    assert(nullptr != src);

    static constexpr u32 HeaderSize = sizeof(u32) * 258;
    // Bytes that are always written after the last symbol: state + header.
    static constexpr u32 TailSize = HeaderSize + sizeof(State);
    if(nullptr == dst || nullptr == src || 0 == src_size || dst_size < TailSize) {
        return 0;
    }

    u32 freqs[256];
    count(freqs, src_size, src);
    u32 cum_freqs[257];
    cumulative(cum_freqs, freqs);
    normalize(freqs, cum_freqs, ProbScale);

    EncSymbol symbols[256];
    for(u32 i = 0; i < 256; ++i) {
        init(symbols[i], cum_freqs[i], freqs[i], ProbBits);
    }

    State rans;
    init(rans);
    u8* ptr = dst + dst_size;
    // Lowest address the renormalization bytes may reach while still leaving
    // room for the state and the header.
    u8* const limit = dst + TailSize;

    // rANS decodes in reverse order of encoding, so feed the symbols back to front.
    for(u32 i = src_size; 0 < i; --i) {
        const EncSymbol& symbol = symbols[src[i - 1]];

        // put() emits at most two bytes per symbol. Only when fewer than two
        // bytes are left, count exactly how many this symbol needs.
        const u32 room = static_cast<u32>(ptr - limit);
        if(room < 2) {
            u32 needed = 0;
            for(State x = rans; symbol.x_max_ <= x; x >>= 8) {
                ++needed;
            }
            if(room < needed) {
                return 0;
            }
        }
        put(rans, ptr, symbol);
    }

    // ptr >= limit holds here, so the state and the header are guaranteed to fit.
    flush(ptr, rans);
    ptr -= HeaderSize;

    // ptr has no particular alignment; copy the header bytewise.
    ::memcpy(ptr, &src_size, sizeof(u32));
    ::memcpy(ptr + sizeof(u32), cum_freqs, sizeof(cum_freqs));

    return static_cast<u32>(dst + dst_size - ptr);
}
337+
338+
// Decompresses a buffer produced by encode().
//
// src must point at the first encoded byte (the u32 original size) and
// src_size is the number of encoded bytes. Layout expected at src:
//   u32       original size (native byte order)
//   u32[257]  normalized cumulative frequencies (native byte order)
//   u8[4]     initial coder state, least significant byte first
//   u8[...]   renormalization bytes
//
// Returns the number of bytes written to dst (the stored original size), or 0
// when the arguments are invalid, dst is too small, or the header is
// inconsistent.
u32 rANS::decode(u32 dst_size, u8* dst, u32 src_size, const u8* src)
{
    assert(0 < dst_size);
    assert(nullptr != dst);
    assert(0 < src_size);
    assert(nullptr != src);

    // Size field plus 257 cumulative frequencies. The coded stream starts
    // right behind these 258 words (not 257).
    static constexpr u32 HeaderSize = sizeof(u32) * 258;
    assert(HeaderSize + sizeof(State) <= src_size);
    if(nullptr == dst || nullptr == src || src_size < HeaderSize + sizeof(State)) {
        return 0;
    }

    u32 original_size;
    ::memcpy(&original_size, src, sizeof(u32));
    if(dst_size < original_size) {
        return 0;
    }

    // src has no particular alignment; copy the table instead of aliasing it.
    u32 cum_freqs[257];
    ::memcpy(cum_freqs, src + sizeof(u32), sizeof(cum_freqs));

    // The table indexes cum2sym below, so reject anything that is not a
    // non-decreasing sequence running from 0 to ProbScale.
    if(0 != cum_freqs[0] || ProbScale != cum_freqs[256]) {
        return 0;
    }
    for(u32 s = 0; s < 256; ++s) {
        if(cum_freqs[s + 1] < cum_freqs[s]) {
            return 0;
        }
    }

    // Map every slot of the probability range back to its symbol.
    u8 cum2sym[ProbScale] = {};
    for(u32 s = 0; s < 256; ++s) {
        for(u32 i = cum_freqs[s]; i < cum_freqs[s + 1]; ++i) {
            cum2sym[i] = static_cast<u8>(s);
        }
    }

    const u8* ptr = src + HeaderSize;
    State rans;
    init_decode(rans, ptr);
    // The encoder always leaves the state inside its normalization interval.
    if(rans < rANSByteLowBounds) {
        return 0;
    }

    // NOTE(review): advance() does not know src_size; a truncated stream can
    // still be read past its end. Bounding those reads needs a change to the helper.
    for(u32 i = 0; i < original_size; ++i) {
        const u8 s = cum2sym[get(rans, ProbBits)];
        dst[i] = s;
        const u32 freq = cum_freqs[s + 1] - cum_freqs[s];
        advance(rans, ptr, cum_freqs[s], freq, ProbBits);
    }
    return original_size;
}
368+
} // namespace cppans
369+
#endif
370+
#endif INC_CPPANS_H_

cpprcoder.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ Put '#define CPPRCODER_IMPLEMENTATION' before including this file to create the
5050
*/
5151
#include <cassert>
5252
#include <cstdint>
53+
#include <cstring>
5354

5455
#define CPPRCODER_USE_SIMD
5556

0 commit comments

Comments
 (0)