-
Notifications
You must be signed in to change notification settings - Fork 2k
/
Copy pathdecode_image.ts
90 lines (85 loc) · 3.33 KB
/
decode_image.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import {Tensor3D, tensor3d, util} from '@tensorflow/tfjs-core';
import * as jpeg from 'jpeg-js';
/**
 * Image container formats that `getImageType` can recognize from a file's
 * leading bytes. Values are lowercase format names.
 */
enum ImageType {
  JPEG = 'jpeg',
  PNG = 'png',
  GIF = 'gif',
  BMP = 'bmp'
}
/**
 * Decode a JPEG-encoded image to a 3D Tensor of dtype `int32`.
 *
 * ```js
 * // Load an image as a Uint8Array
 * const imageUri = 'http://image-uri-here.example.com/image.jpg';
 * const response = await fetch(imageUri, {}, { isBinary: true });
 * const imageDataArrayBuffer = await response.arrayBuffer();
 * const imageData = new Uint8Array(imageDataArrayBuffer);
 *
 * // Decode image data to a tensor
 * const imageTensor = decodeJpeg(imageData);
 * ```
 *
 * @param contents The JPEG-encoded image in a Uint8Array.
 * @param channels An optional int. Defaults to 3. The type admits
 *     0 (use the number of channels in the JPEG-encoded image),
 *     1 (grayscale) and 3 (RGB), but only 3 is implemented at this time;
 *     any other value fails an assertion.
 * @returns A 3D Tensor of dtype `int32` with shape [height, width, 3].
 * @throws Error (via `util.assert`) if `contents` is not detected as a JPEG
 *     or `channels` is not 3. Contents that match no supported image
 *     signature make the type check itself throw.
 *
 * @doc {heading: 'Media', subheading: 'Images'}
 */
export function decodeJpeg(
    contents: Uint8Array, channels: 0|1|3 = 3): Tensor3D {
  util.assert(
      getImageType(contents) === ImageType.JPEG,
      () => 'The passed contents are not a valid JPEG image');
  // TODO: implement channels 0 and 1; only RGB output is produced below.
  util.assert(
      channels === 3, () => 'Only 3 channels is supported at this time');
  const {width, height, data} = jpeg.decode(contents, {useTArray: true});

  // jpeg.decode yields 4 bytes per pixel, and the 4th (alpha) byte is always
  // 255. Keep only the first three (RGB) bytes of each pixel.
  const numPixels = width * height;
  const buffer = new Uint8Array(numPixels * 3);
  for (let pixel = 0; pixel < numPixels; pixel++) {
    const src = pixel * 4;
    const dst = pixel * 3;
    buffer[dst] = data[src];
    buffer[dst + 1] = data[src + 1];
    buffer[dst + 2] = data[src + 2];
  }
  // Pin the documented `int32` dtype explicitly rather than relying on
  // inference from the Uint8Array values.
  return tensor3d(buffer, [height, width, channels], 'int32');
}
/**
 * Helper function to get image type based on starting bytes of the image file.
 *
 * Each supported format is identified by a fixed byte prefix (its "magic
 * number"). `content` matches a format when it holds at least that many bytes
 * and begins with exactly that prefix.
 *
 * @param content The encoded image bytes.
 * @returns The detected `ImageType`.
 * @throws Error if `content` does not start with a recognized JPEG, GIF, PNG
 *     or BMP signature (including when it is shorter than every signature).
 */
function getImageType(content: Uint8Array): ImageType {
  // Classify the contents of a file based on starting bytes (aka magic number:
  // tslint:disable-next-line:max-line-length
  // https://en.wikipedia.org/wiki/Magic_number_(programming)#Magic_numbers_in_files)
  // This aligns with TensorFlow Core code:
  // tslint:disable-next-line:max-line-length
  // https://github.com/tensorflow/tensorflow/blob/4213d5c1bd921f8d5b7b2dc4bbf1eea78d0b5258/tensorflow/core/kernels/decode_image_op.cc#L44
  const signatures: ReadonlyArray<readonly [ImageType, readonly number[]]> = [
    // JPEG byte chunk starts with `ff d8 ff`
    [ImageType.JPEG, [0xff, 0xd8, 0xff]],
    // GIF byte chunk starts with `47 49 46 38` ("GIF8")
    [ImageType.GIF, [0x47, 0x49, 0x46, 0x38]],
    // PNG byte chunk starts with `\211 P N G \r \n \032 \n`
    // (89 50 4E 47 0D 0A 1A 0A)
    [ImageType.PNG, [0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a]],
    // BMP byte chunk starts with `42 4d` ("BM")
    [ImageType.BMP, [0x42, 0x4d]],
  ];
  for (const [type, magic] of signatures) {
    // The buffer only needs to be as long as the signature being tested.
    if (content.length >= magic.length &&
        magic.every((byte, i) => content[i] === byte)) {
      return type;
    }
  }
  throw new Error(
      'Expected image (BMP, JPEG, PNG, or GIF), but got unsupported image type');
}