Explorar el Código

add fma implementation for x86

correctly rounded double precision fma using extended
precision arithmetic for ld80 systems (x87)
nsz hace 13 años
padre
commit
b1cbd70743
Se ha modificado 1 fichero con 143 adiciones y 7 borrados
  1. 143 7
      src/math/fma.c

+ 143 - 7
src/math/fma.c

@@ -1,3 +1,141 @@
+#include <fenv.h>
+#include "libm.h"
+
+#if LDBL_MANT_DIG==64 && LDBL_MAX_EXP==16384
+union ld80 { /* overlays a long double with its raw bit fields; only compiled when LDBL_MANT_DIG==64 && LDBL_MAX_EXP==16384 */
+	long double x; /* the value viewed as a floating-point number */
+	struct { /* NOTE(review): field order presumes the little-endian x87 layout -- confirm before using on another target */
+		uint64_t m; /* 64-bit mantissa field; adjust() tests its low 10 bits and nudges it by one */
+		uint16_t e : 15; /* 15-bit exponent field; returned as-is by getexp() */
+		uint16_t s : 1; /* sign bit; compared between hi and lo in adjust() */
+		uint16_t pad; /* trailing 16 bits; not read or written by the code in this file section */
+	} bits;
+};
+
+/*
+ * exact add, assumes exponent_x >= exponent_y
+ *
+ * Stores the rounded sum x + y in *hi and the value y - ((x + y) - x)
+ * in *lo, so that the pair (hi, lo) is meant to represent x + y without
+ * rounding error.  The caller is responsible for the exponent ordering
+ * and, per the note in fma(), for running this in round-to-nearest mode.
+ * The statement order is significant and must not be rearranged.
+ */
+static void add(long double *hi, long double *lo, long double x, long double y)
+{
+	long double r;
+
+	r = x + y; /* rounded sum */
+	*hi = r;
+	r -= x; /* the portion of y that actually made it into the rounded sum */
+	*lo = y - r; /* the portion of y that was lost to rounding */
+}
+
+/*
+TODO(nsz): probably a simpler mul is enough if we assume x and y are doubles,
+so the last 11 bits are all zeros, there are no subnormals, etc.
+*/
+/*
+ * exact mul, assumes no over/underflow
+ *
+ * Stores the rounded product x*y in *hi and a correction term in *lo.
+ * Each operand is first split into a high part (xh, yh) and a low part
+ * (xl = x - xh, yl = y - yh) with the help of the constant c = 1 + 2^32;
+ * the correction is then accumulated from the four partial products,
+ * starting with xh*yh - hi.  The order of operations is significant and
+ * must not be rearranged.  Per the note in fma(), this is expected to run
+ * in round-to-nearest mode.
+ */
+static void mul(long double *hi, long double *lo, long double x, long double y)
+{
+	static const long double c = 1.0 + 0x1p32L; /* splitting constant, 2^32 + 1 */
+	long double cx, xh, xl, cy, yh, yl;
+
+	cx = c*x;
+	xh = (x - cx) + cx; /* high part of x */
+	xl = x - xh; /* low part of x */
+	cy = c*y;
+	yh = (y - cy) + cy; /* high part of y */
+	yl = y - yh; /* low part of y */
+	*hi = x*y; /* rounded product */
+	*lo = (xh*yh - *hi) + xh*yl + xl*yh + xl*yl; /* partial products minus the rounded product, largest term first */
+}
+
+/*
+assume (long double)(hi+lo) == hi
+return an adjusted hi so that rounding it to double is correct
+
+hi has a 64-bit mantissa, 11 bits more than a double.  If any of the low
+10 bits is set, lo cannot change how hi rounds to double and hi is
+returned unchanged.  Otherwise the mantissa is moved by one unit in the
+direction of lo (up when the signs agree, down when they differ), so the
+later rounding to double sees that the value is not exactly hi.
+
+Incrementing cannot overflow, because the low 10 bits are known to be
+zero.  Decrementing can borrow out of the explicit integer bit when hi
+is a power of two (mantissa 0x8000000000000000): the mantissa would
+become 0x7fffffffffffffff with an unchanged exponent, which is not a
+valid normal ld80 encoding.  In that case the mantissa is renormalized
+and the exponent is lowered by one, which yields a value just below hi.
+*/
+static long double adjust(long double hi, long double lo)
+{
+	union ld80 uhi, ulo;
+
+	if (lo == 0)
+		return hi;
+	uhi.x = hi;
+	if (uhi.bits.m & 0x3ff)
+		return hi;
+	ulo.x = lo;
+	if (uhi.bits.s == ulo.bits.s) {
+		uhi.bits.m++;
+	} else {
+		uhi.bits.m--;
+		/* borrow from the exponent if the explicit msb was lost */
+		if (uhi.bits.m == 0x7fffffffffffffffULL) {
+			uhi.bits.m = 0xfffffffffffffffeULL;
+			uhi.bits.e--;
+		}
+	}
+	return uhi.x;
+}
+
+/*
+ * Adds x and y with add() and passes the resulting (sum, error) pair to
+ * adjust(), returning the adjusted sum.  The exponent precondition of
+ * add() applies to the arguments in the order given.
+ */
+static long double dadd(long double x, long double y)
+{
+	long double sum, err;
+
+	add(&sum, &err, x, y);
+	return adjust(sum, err);
+}
+
+/*
+ * Multiplies x and y with mul() and passes the resulting
+ * (product, error) pair to adjust(), returning the adjusted product.
+ */
+static long double dmul(long double x, long double y)
+{
+	long double prod, err;
+
+	mul(&prod, &err, x, y);
+	return adjust(prod, err);
+}
+
+/* Returns the raw 15-bit exponent field of x's long double encoding. */
+static int getexp(long double x)
+{
+	union ld80 repr = { .x = x };
+
+	return repr.bits.e;
+}
+/*
+ * Fused multiply-add for double, x*y + z, computed through long double
+ * helper routines (mul, add, dadd, dmul) on targets where long double is
+ * the 80-bit extended format selected by the surrounding #if.
+ *
+ * Non-finite and zero operands are answered with plain arithmetic.
+ * Otherwise the product is split into xy + lo1 by mul(), z is combined
+ * with xy by add() (ordered by exponent), and the pieces are summed in
+ * the caller's rounding mode.  The rounding mode is temporarily switched
+ * to FE_TONEAREST for the exact steps and restored before the final sum.
+ */
+double fma(double x, double y, double z)
+{
+	long double hi, lo1, lo2, xy;
+	int round, ez, exy;
+
+	/* handle +-inf,nan: ordinary arithmetic already gives the result for these */
+	if (!isfinite(x) || !isfinite(y))
+		return x*y + z;
+	if (!isfinite(z))
+		return z;
+	/* handle +-0 */
+	if (x == 0.0 || y == 0.0)
+		return x*y + z;
+	round = fegetround(); /* remember the caller's rounding mode */
+	if (z == 0.0) { /* nothing to add: only the product has to be rounded */
+		if (round == FE_TONEAREST)
+			return dmul(x, y);
+		return x*y;
+	}
+
+	/* exact mul and add require nearest rounding */
+	/* spurious inexact exceptions may be raised */
+	fesetround(FE_TONEAREST);
+	mul(&xy, &lo1, x, y); /* x*y == xy + lo1 */
+	exy = getexp(xy);
+	ez = getexp(z);
+	if (ez > exy) { /* z has the larger exponent, so it goes first as add() requires */
+		add(&hi, &lo2, z, xy);
+	} else if (ez > exy - 12) { /* exponents are close: add exactly with xy first */
+		add(&hi, &lo2, xy, z);
+		if (hi == 0) { /* xy and z cancelled; only lo1 is left to contribute */
+			fesetround(round);
+			/* TODO: verify that the sign of 0 is always correct */
+			return (xy + z) + lo1;
+		}
+	} else {
+		/*
+		ez <= exy - 12
+		the 12 extra bits (1guard, 11round+sticky) are needed so with
+			lo = dadd(lo1, lo2)
+		elo <= ehi - 11, and we use the last 10 bits in adjust so
+			dadd(hi, lo)
+		gives correct result when rounded to double
+		*/
+		hi = xy;
+		lo2 = z;
+	}
+	fesetround(round); /* restore the caller's rounding mode before the final summation */
+	if (round == FE_TONEAREST)
+		return dadd(hi, dadd(lo1, lo2));
+	return hi + (lo1 + lo2);
+}
+#else
 /* origin: FreeBSD /usr/src/lib/msun/src/s_fma.c */
 /*-
  * Copyright (c) 2005-2011 David Schultz <[email protected]>
@@ -25,9 +163,6 @@
  * SUCH DAMAGE.
  */
 
-#include <fenv.h>
-#include "libm.h"
-
 /*
  * A struct dd represents a floating-point number with twice the precision
  * of a double.  We maintain the invariant that "hi" stores the 53 high-order
@@ -178,14 +313,14 @@ double fma(double x, double y, double z)
 	 * return values here are crucial in handling special cases involving
 	 * infinities, NaNs, overflows, and signed zeroes correctly.
 	 */
-	if (x == 0.0 || y == 0.0)
-		return (x * y + z);
-	if (z == 0.0)
-		return (x * y);
 	if (!isfinite(x) || !isfinite(y))
 		return (x * y + z);
 	if (!isfinite(z))
 		return (z);
+	if (x == 0.0 || y == 0.0)
+		return (x * y + z);
+	if (z == 0.0)
+		return (x * y);
 
 	xs = frexp(x, &ex);
 	ys = frexp(y, &ey);
@@ -278,3 +413,4 @@ double fma(double x, double y, double z)
 	else
 		return (add_and_denormalize(r.hi, adj, spread));
 }
+#endif