需要源码和数据集请点赞关注收藏后评论区留言私信~~~

一、OCR文字识别简介

OCR（光学字符识别）是利用计算机自动识别字符的技术，是模式识别应用的一个重要领域。人们在生产和生活中要处理大量的文字、报表和文本。为了减轻人们的劳动、提高处理效率，从上世纪50年代起人们就开始探讨文字识别方法，并研制出了光学字符识别器。

OCR（Optical Character Recognition，光学字符识别）图像文字识别是人工智能的重要分支，它赋予计算机类似人眼的功能，使其能够“看图识字”。图像文字识别系统的流程一般分为图像采集、文字检测、文字识别以及结果输出四个部分。

二、OCR文字识别项目实战

1:数据集简介

MSRA-TD500数据集共包含500张自然场景图像，其分辨率在1296 × 864至1920 × 1280之间，涵盖了室内商场、标识牌、室外街道、广告牌等大多数场景。文本包含中文和英文，具有不同的字体、大小和倾斜方向。部分数据集图像如下图所示。

数据集的目录结构如下，分为训练集和测试集。

2:项目结构

整体项目结构如下：上半部分是CRAFT、CRNN等算法和模型的定义，下半部分是测试代码。

CRAFT算法实现文本行检测的流程如下图所示：首先将完整的文字区域输入CRAFT文字检测网络，得到字符级的文字得分热图（Text Score）和字符级的文本连接得分热图（Link Score），最后根据连通域得到每个文本行的位置。

3:效果展示

开始运行代码

输出运行结果如下，可以放入不同的图片进行测试。

三、代码

部分代码如下。需要全部代码和数据集请点赞关注收藏后评论区留言私信~~~

"""Train a keras-ocr text recognizer on SynthText90 (MJSynth) using multiple GPUs.

The dataset archive is downloaded and extracted into ``--data_path`` when it is
not already present. Checkpoints and CSV logs are written to ``--logs_path``.
If a checkpoint named after ``--model_id`` already exists, its weights are
loaded and training continues into new, timestamped checkpoint/log files.
"""
# pylint: disable=invalid-name
import argparse
import datetime
import functools
import itertools
import math
import os
import random
import string
import tarfile
import urllib.request

import cv2
import imgaug
import numpy as np
import tensorflow as tf
import tqdm

import keras_ocr

# Directory (relative to ``data_path``) that the MJSynth archive extracts to.
DATASET_SUBDIR = 'mnt/ramdisk/max/90kDICT32px'
DATASET_URL = 'https://www.robots.ox.ac.uk/~vgg/data/text/mjsynth.tar.gz'
ARCHIVE_NAME = 'mjsynth.tar.gz'
# Recognizer input size; reused for the tf.data output shapes so they agree.
IMAGE_HEIGHT = 31
IMAGE_WIDTH = 200


def get_filepaths(data_path, split):
    """Get the list of image filepaths for a given split (train, val, or test).

    Args:
        data_path: Directory that contains the extracted ``mnt`` folder.
        split: Split name; selects ``annotation_{split}.txt``.

    Returns:
        A list of image paths (joined onto ``data_path``) in annotation-file order.
    """
    dataset_dir = os.path.join(data_path, DATASET_SUBDIR)
    annotation_path = os.path.join(dataset_dir, f'annotation_{split}.txt')
    with open(annotation_path, 'r', encoding='utf-8') as text_file:
        # The first whitespace-separated field of each line is the image path;
        # its first two characters (presumably "./") are dropped. Blank lines are
        # skipped so they do not turn into the dataset directory itself.
        return [
            os.path.join(dataset_dir, line.split()[0][2:])
            for line in text_file
            if line.strip()
        ]


def download_extract_and_process_dataset(data_path):
    """Download and extract the SynthText90 (MJSynth) dataset into ``data_path``.

    The download is skipped when either the archive or the extracted ``mnt``
    directory already exists; extraction is skipped when ``mnt`` exists.
    """
    archive_filepath = os.path.join(data_path, ARCHIVE_NAME)
    extraction_directory = os.path.join(data_path, 'mnt')
    if not os.path.isfile(archive_filepath) and not os.path.isdir(extraction_directory):
        print('Downloading the dataset.')
        os.makedirs(data_path, exist_ok=True)
        # Download under a temporary name so an interrupted download is not
        # mistaken for a complete archive on the next run.
        partial_filepath = archive_filepath + '.part'
        urllib.request.urlretrieve(DATASET_URL, partial_filepath)
        os.replace(partial_filepath, archive_filepath)
    if not os.path.isdir(extraction_directory):
        print('Extracting files.')
        with tarfile.open(archive_filepath) as tfile:
            if hasattr(tarfile, 'data_filter'):
                # The archive comes from the network: reject path-traversal and
                # special members where the running Python supports filters.
                tfile.extractall(data_path, filter='data')
            else:
                # NOTE: this Python has no extraction filters; members are not validated.
                tfile.extractall(data_path)


def get_image_generator(filepaths, augmenter, width, height):
    """Yield ``(image, text)`` pairs from SynthText90 filepaths indefinitely.

    The first pass follows the order of ``filepaths``. Every later pass uses a
    fresh random permutation. The label is the lower-cased second
    underscore-separated field of the filename. Unreadable images are reported
    and skipped.

    Args:
        filepaths: Image paths; copied, the caller's list is not modified.
        augmenter: Optional imgaug augmenter applied to each fitted image.
        width: Target width passed to ``keras_ocr.tools.fit``.
        height: Target height passed to ``keras_ocr.tools.fit``.

    Raises:
        RuntimeError: If a full pass over ``filepaths`` produced no readable image
            (which would otherwise loop forever without yielding).
    """
    filepaths = list(filepaths)  # private copy: it is shuffled in place below
    if not filepaths:
        return
    for epoch in itertools.count():
        if epoch > 0:
            # Reshuffle between passes. This must be done on the list we iterate;
            # itertools.cycle would replay its own saved copy and ignore shuffles.
            random.shuffle(filepaths)
        yielded_any = False
        for filepath in filepaths:
            text = os.path.basename(filepath).split('_')[1].lower()
            image = cv2.imread(filepath)
            if image is None:
                # cv2.imread signals failure with None; skip instead of crashing below.
                print(f'An error occurred reading: {filepath}')
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = keras_ocr.tools.fit(
                image,
                width=width,
                height=height,
                # Random padding colour; ``high`` is exclusive, so 256 covers 0..255.
                cval=np.random.randint(low=0, high=256, size=3).astype('uint8'),
            )
            if augmenter is not None:
                image = augmenter.augment_image(image)
            yielded_any = True
            yield image, text
        if not yielded_any:
            raise RuntimeError('None of the provided image files could be read.')


def _parse_args():
    """Parse the command-line options of this training script."""
    parser = argparse.ArgumentParser(
        description='Train a keras-ocr recognizer on the SynthText90 (MJSynth) dataset.')
    parser.add_argument('--model_id',
                        default='recognizer',
                        help='The name to use for saving model checkpoints.')
    parser.add_argument(
        '--data_path',
        default='.',
        help='The path to the directory containing the dataset and where we will put our logs.')
    parser.add_argument(
        '--logs_path',
        default='./logs',
        help=('The path to where logs and checkpoints should be stored. '
              'If a checkpoint matching "model_id" is found, training will resume from that point.'))
    # type=int: without it a value given on the command line stays a string and
    # breaks the steps-per-epoch division.
    parser.add_argument('--batch_size', default=16, type=int, help='The training batch size to use.')
    parser.add_argument('--no-file-verification', dest='verify_files', action='store_false')
    parser.set_defaults(verify_files=True)
    return parser.parse_args()


def _make_dataset(recognizer, image_generator, batch_size):
    """Wrap ``recognizer.get_batch_generator`` over ``image_generator`` in a tf.data.Dataset."""
    return tf.data.Dataset.from_generator(
        functools.partial(recognizer.get_batch_generator,
                          image_generator=image_generator,
                          batch_size=batch_size),
        output_types=((tf.float32, tf.int64, tf.float64, tf.int64), tf.float64),
        output_shapes=(
            (
                tf.TensorShape([None, IMAGE_HEIGHT, IMAGE_WIDTH, 1]),
                tf.TensorShape([None, recognizer.training_model.input_shape[1][1]]),
                tf.TensorShape([None, 1]),
                tf.TensorShape([None, 1]),
            ),
            tf.TensorShape([None, 1]),
        ),
    )


def main():
    """Build (and optionally resume) the recognizer, then train it."""
    args = _parse_args()
    weights_path = os.path.join(args.logs_path, args.model_id + '.h5')
    csv_path = os.path.join(args.logs_path, args.model_id + '.csv')
    download_extract_and_process_dataset(args.data_path)
    with tf.distribute.MirroredStrategy().scope():
        recognizer = keras_ocr.recognition.Recognizer(
            alphabet=string.digits + string.ascii_lowercase,
            height=IMAGE_HEIGHT,
            width=IMAGE_WIDTH,
            stn=False,
            optimizer=tf.keras.optimizers.RMSprop(),
            weights=None,
        )
        if os.path.isfile(weights_path):
            print('Loading saved weights and creating new version.')
            # Load the existing checkpoint before weights_path is pointed at the
            # new, timestamped output file.
            recognizer.model.load_weights(weights_path)
            # No ':' in the timestamp so the filenames are also valid on Windows.
            dt_string = datetime.datetime.now().strftime('%Y-%m-%dT%H-%M-%S')
            weights_path = os.path.join(args.logs_path, args.model_id + '_' + dt_string + '.h5')
            csv_path = os.path.join(args.logs_path, args.model_id + '_' + dt_string + '.csv')
    augmenter = imgaug.augmenters.Sequential([
        imgaug.augmenters.Multiply((0.9, 1.1)),
        imgaug.augmenters.GammaContrast(gamma=(0.5, 3.0)),
        imgaug.augmenters.Invert(0.25, per_channel=0.5),
    ])
    os.makedirs(args.logs_path, exist_ok=True)

    training_filepaths = get_filepaths(data_path=args.data_path, split='train')
    validation_filepaths = get_filepaths(data_path=args.data_path, split='val')
    if args.verify_files:
        # An explicit check rather than `assert`, which is stripped under `python -O`.
        missing = [
            filepath for filepath in tqdm.tqdm(training_filepaths + validation_filepaths,
                                               desc='Checking filepaths.')
            if not os.path.isfile(filepath)
        ]
        if missing:
            raise FileNotFoundError(
                f'Some files appear to be missing ({len(missing)} in total, e.g. {missing[0]}).')

    input_height = recognizer.model.input_shape[1]
    input_width = recognizer.model.input_shape[2]
    training_image_generator = get_image_generator(filepaths=training_filepaths,
                                                   augmenter=augmenter,
                                                   width=input_width,
                                                   height=input_height)
    validation_image_generator = get_image_generator(filepaths=validation_filepaths,
                                                     augmenter=None,
                                                     width=input_width,
                                                     height=input_height)
    training_steps = math.ceil(len(training_filepaths) / args.batch_size)
    validation_steps = math.ceil(len(validation_filepaths) / args.batch_size)

    training_generator = _make_dataset(recognizer, training_image_generator, args.batch_size)
    validation_generator = _make_dataset(recognizer, validation_image_generator, args.batch_size)
    callbacks = [
        tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                         min_delta=0,
                                         patience=10,
                                         restore_best_weights=False),
        tf.keras.callbacks.ModelCheckpoint(weights_path, monitor='val_loss', save_best_only=True),
        tf.keras.callbacks.CSVLogger(csv_path),
    ]
    recognizer.training_model.fit(
        x=training_generator,
        steps_per_epoch=training_steps,
        validation_steps=validation_steps,
        validation_data=validation_generator,
        callbacks=callbacks,
        epochs=1000,
    )


if __name__ == '__main__':
    main()
"""This script is what was used to generate thebackgrounds.zip and fonts.zip files."""# pylint: disable=invalid-name,redefined-outer-nameimport jsonimport urllib.requestimport urllib.parseimport concurrentimport shutilimport zipfileimport globimport osimport numpy as npimport tqdmimport cv2import keras_ocrif __name__ == '__main__':fonts_commit = 'a0726002eab4639ee96056a38cd35f6188011a81'fonts_sha256 = 'e447d23d24a5bbe8488200a058cd5b75b2acde525421c2e74dbfb90ceafce7bf'fonts_source_zip_filepath = keras_ocr.tools.download_and_verify(url=f'https://github.com/google/fonts/archive/{fonts_commit}.zip',cache_dir='.',sha256=fonts_sha256)shutil.rmtree('fonts-raw', ignore_errors=True)with zipfile.ZipFile(fonts_source_zip_filepath) as zfile:zfile.extractall(path='fonts-raw')retained_fonts = []sha256s = []basenames = []# The blacklist includes fonts that, at least for the English alphabet, were found# to be illegible (e.g., thin fonts) or render in unexpected ways (e.g., mathematics# fonts).blacklist = ['AlmendraDisplay-Regular.ttf', 'RedactedScript-Bold.ttf', 'RedactedScript-Regular.ttf','Sevillana-Regular.ttf', 'Mplus1p-Thin.ttf', 'Stalemate-Regular.ttf', 'jsMath-cmsy10.ttf','Codystar-Regular.ttf', 'AdventPro-Thin.ttf', 'RoundedMplus1c-Thin.ttf','EncodeSans-Thin.ttf', 'AlegreyaSans-ThinItalic.ttf', 'AlegreyaSans-Thin.ttf','FiraSans-Thin.ttf', 'FiraSans-ThinItalic.ttf', 'WorkSans-Thin.ttf','Tomorrow-ThinItalic.ttf', 'Tomorrow-Thin.ttf', 'Italianno-Regular.ttf','IBMPlexSansCondensed-Thin.ttf', 'IBMPlexSansCondensed-ThinItalic.ttf','Lato-ExtraLightItalic.ttf', 'LibreBarcode128Text-Regular.ttf','LibreBarcode39-Regular.ttf', 'LibreBarcode39ExtendedText-Regular.ttf','EncodeSansExpanded-ExtraLight.ttf', 'Exo-Thin.ttf', 'Exo-ThinItalic.ttf','DrSugiyama-Regular.ttf', 'Taviraj-ThinItalic.ttf', 'SixCaps.ttf', 'IBMPlexSans-Thin.ttf','IBMPlexSans-ThinItalic.ttf', 'AdobeBlank-Regular.ttf','FiraSansExtraCondensed-ThinItalic.ttf', 'HeptaSlab[wght].ttf', 
'Karla-Italic[wght].ttf','Karla[wght].ttf', 'RalewayDots-Regular.ttf', 'FiraSansCondensed-ThinItalic.ttf','jsMath-cmex10.ttf', 'LibreBarcode39Text-Regular.ttf', 'LibreBarcode39Extended-Regular.ttf','EricaOne-Regular.ttf', 'ArimaMadurai-Thin.ttf', 'IBMPlexSerif-ExtraLight.ttf','IBMPlexSerif-ExtraLightItalic.ttf', 'IBMPlexSerif-ThinItalic.ttf', 'IBMPlexSerif-Thin.ttf','Exo2-Thin.ttf', 'Exo2-ThinItalic.ttf', 'BungeeOutline-Regular.ttf', 'Redacted-Regular.ttf','JosefinSlab-ThinItalic.ttf', 'GothicA1-Thin.ttf', 'Kanit-ThinItalic.ttf', 'Kanit-Thin.ttf','AlegreyaSansSC-ThinItalic.ttf', 'AlegreyaSansSC-Thin.ttf', 'Chathura-Thin.ttf','Blinker-Thin.ttf', 'Italiana-Regular.ttf', 'Miama-Regular.ttf', 'Grenze-ThinItalic.ttf','LeagueScript-Regular.ttf', 'BigShouldersDisplay-Thin.ttf', 'YanoneKaffeesatz[wght].ttf','BungeeHairline-Regular.ttf', 'JosefinSans-Thin.ttf', 'JosefinSans-ThinItalic.ttf','Monofett.ttf', 'Raleway-ThinItalic.ttf', 'Raleway-Thin.ttf', 'JosefinSansStd-Light.ttf','LibreBarcode128-Regular.ttf']for filepath in tqdm.tqdm(sorted(glob.glob('fonts-raw/**/**/**/*.ttf')),desc='Filtering fonts.'):sha256 = keras_ocr.tools.sha256sum(filepath)basename = os.path.basename(filepath)# We check the sha256 and filenames because some of the fonts# in the repository are duplicated (see TRIVIA.md).if sha256 in sha256s or basename in basenames or basename in blacklist:continuesha256s.append(sha256)basenames.append(basename)retained_fonts.append(filepath)retained_font_families = set([filepath.split(os.sep)[-2] for filepath in retained_fonts])added = []with zipfile.ZipFile(file='fonts.zip', mode='w') as zfile:for font_family in tqdm.tqdm(retained_font_families, desc='Saving ZIP file.'):# We want to keep all the metadata files plus# the retained font files. 
And we don't want# to add the same file twice.files = [input_filepath for input_filepath in glob.glob(f'fonts-raw/**/**/{font_family}/*')if input_filepath not in added and(input_filepath in retained_fonts or os.path.splitext(input_filepath)[1] != '.ttf')]added.extend(files)for input_filepath in files:zfile.write(filename=input_filepath,arcname=os.path.join(*input_filepath.split(os.sep)[-2:]))print('Finished saving fonts file.')# pylint: disable=line-too-longurl = ('https://commons.wikimedia.org/w/api.php" />创作不易 觉得有帮助请点赞关注收藏~~~